[GNNLux] GCNConv, ChebConv, GNNChain #462

Merged
merged 2 commits on Jul 26, 2024
13 changes: 10 additions & 3 deletions GNNLux/src/GNNLux.jl
@@ -1,15 +1,22 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib
using LuxCore: LuxCore, AbstractExplicitLayer
using Lux: glorot_uniform, zeros32
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, glorot_uniform, zeros32
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@reexport using GNNGraphs

include("layers/basic.jl")
export GNNLayer,
GNNContainerLayer,
GNNChain

include("layers/conv.jl")
export GraphConv
export GCNConv,
ChebConv,
GraphConv

end #module

61 changes: 61 additions & 0 deletions GNNLux/src/layers/basic.jl
@@ -0,0 +1,61 @@
"""
abstract type GNNLayer <: AbstractExplicitLayer end

An abstract type from which graph neural network layers are derived.
It is derived from Lux's `AbstractExplicitLayer` type.

See also [`GNNChain`](@ref GNNLux.GNNChain).
"""
abstract type GNNLayer <: AbstractExplicitLayer end

abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end

@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
layers <: NamedTuple
end

GNNChain(xs...) = GNNChain(; (Symbol("layer_", i) => x for (i, x) in enumerate(xs))...)

function GNNChain(; kw...)
:layers in Base.keys(kw) &&
throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
nt = NamedTuple{keys(kw)}(values(kw))
nt = map(_wrapforchain, nt)
return GNNChain(nt)
end

_wrapforchain(l::AbstractExplicitLayer) = l
_wrapforchain(l) = Lux.WrappedFunction(l)

Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers))
Base.getindex(c::GNNChain, i::Int) = c.layers[i]
Base.getindex(c::GNNChain, i::AbstractVector) = GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))

function Base.getproperty(c::GNNChain, name::Symbol)
hasfield(typeof(c), name) && return getfield(c, name)
layers = getfield(c, :layers)
hasfield(typeof(layers), name) && return getfield(layers, name)
throw(ArgumentError("$(typeof(c)) has no field or layer $name"))
end

Base.length(c::GNNChain) = length(c.layers)
Base.lastindex(c::GNNChain) = lastindex(c.layers)
Base.firstindex(c::GNNChain) = firstindex(c.layers)

LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end])

(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st)

function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times
newst = (;)
for (name, l) in pairs(layers)
x, s′ = _applylayer(l, g, x, getproperty(ps, name), getproperty(st, name))
newst = merge(newst, (; name => s′))
end
return x, newst
end

_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;)
_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
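
For reference (not part of this PR), a minimal usage sketch of the new `GNNChain` with Lux-style explicit parameters and states; the layer sizes and graph below are illustrative:

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = rand_graph(10, 40)                 # 10 nodes, 40 edges
x = randn(Float32, 3, 10)              # 3 input features per node

# layers are stored in a NamedTuple as layer_1, layer_2, ...
c = GNNChain(GCNConv(3 => 8, relu), GCNConv(8 => 2))

ps = LuxCore.initialparameters(rng, c) # NamedTuple keyed by layer name
st = LuxCore.initialstates(rng, c)

y, st′ = c(g, x, ps, st)               # size(y) == (2, 10)
```

Plain functions passed to `GNNChain` are wrapped with `Lux.WrappedFunction` by `_wrapforchain`, so they participate in the chain without parameters.
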
152 changes: 115 additions & 37 deletions GNNLux/src/layers/conv.jl
@@ -1,54 +1,132 @@
# Missing Layers

# | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
# | :-------- | :---: |:---: |:---: | :---: | :---: |
# | [`AGNNConv`](@ref) | | | ✓ | | |
# | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`EGNNConv`](@ref) | | | ✓ | | |
# | [`EdgeConv`](@ref) | | | | ✓ | |
# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`GMMConv`](@ref) | | | ✓ | | |
# | [`MEGNetConv`](@ref) | | | ✓ | | |
# | [`NNConv`](@ref) | | | ✓ | | |
# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`SGConv`](@ref) | ✓ | | | | ✓ |
# | [`TransformerConv`](@ref) | | | ✓ | | |


@concrete struct GCNConv <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
add_self_loops::Bool
use_edge_weight::Bool
init_weight
init_bias
σ
end

@doc raw"""
GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
function GCNConv(ch::Pair{Int, Int}, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops::Bool = true,
use_edge_weight::Bool = false,
allow_fast_activation::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
end

Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
weight = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight, bias)
else
return (; weight)
end
end

Performs:
```math
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
```
LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
LuxCore.statelength(d::GCNConv) = 0
LuxCore.outputsize(d::GCNConv) = (d.out_dims,)

where the aggregation type is selected by `aggr`.
function Base.show(io::IO, l::GCNConv)
print(io, "GCNConv(", l.in_dims, " => ", l.out_dims)
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
l.add_self_loops || print(io, ", add_self_loops=false")
!l.use_edge_weight || print(io, ", use_edge_weight=true")
print(io, ")")
end

# Arguments
# TODO norm_fn should be keyword argument only
(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, ps, st; conv_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight=nothing) =
GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st
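
Side note (not part of this PR): a minimal sketch of the edge-weight call path above, assuming a weighted graph; the edge weights and sizes are illustrative.

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
s, t = [1, 1, 2, 3], [2, 3, 1, 1]
w = Float32[0.5, 1.0, 0.5, 1.0]        # one weight per edge
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)

l = GCNConv(3 => 5, relu; use_edge_weight = true)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

y, _ = l(g, x, w, ps, st)              # forwards to GNNlib.gcn_conv with edge_weight = w
                                       # size(y) == (5, g.num_nodes)
```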

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
@concrete struct ChebConv <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
k::Int
init_weight
init_bias
σ
end

# Examples
function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
allow_fast_activation::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, σ)
end

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
in_channel = 3
out_channel = 5
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)
function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv)
weight = l.init_weight(rng, l.out_dims, l.in_dims, l.k)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight, bias)
else
return (; weight)
end
end

LuxCore.parameterlength(l::ChebConv) = l.use_bias ? l.in_dims * l.out_dims * l.k + l.out_dims :
l.in_dims * l.out_dims * l.k
LuxCore.statelength(d::ChebConv) = 0
LuxCore.outputsize(d::ChebConv) = (d.out_dims,)

function Base.show(io::IO, l::ChebConv)
print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", K=", l.K)
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

# create layer
l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean)
(l::ChebConv)(g, x, ps, st) = GNNlib.cheb_conv(l, g, x, ps), st

# forward pass
y = l(g, x)
```
"""
@concrete struct GraphConv <: AbstractExplicitLayer
@concrete struct GraphConv <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight::Function
init_bias::Function
init_weight
init_bias
σ
aggr
end


function GraphConv(ch::Pair{Int, Int}, σ = identity;
aggr = +,
init_weight = glorot_uniform,
@@ -65,10 +143,10 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
weight2 = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight1, weight2, bias)
else
bias = false
return (; weight1, weight2)
end
return (; weight1, weight2, bias)
end

function LuxCore.parameterlength(l::GraphConv)
@@ -90,4 +168,4 @@ function Base.show(io::IO, l::GraphConv)
print(io, ")")
end

(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
(l::GraphConv)(g, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
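
Not part of this PR: a minimal sketch of the new `ChebConv` layer on its own, with Chebyshev order `k = 2`; the graph and sizes are illustrative.

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = rand_graph(10, 40)
x = randn(Float32, 3, 10)

l = ChebConv(3 => 5, 2, relu)           # in => out, k, activation
ps = LuxCore.initialparameters(rng, l)  # weight has size (out, in, k) = (5, 3, 2)
st = LuxCore.initialstates(rng, l)

y, _ = l(g, x, ps, st)                  # size(y) == (5, 10)

# parameter count: in * out * k weight entries + out bias entries
@assert LuxCore.parameterlength(l) == 3 * 5 * 2 + 5
```
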
24 changes: 24 additions & 0 deletions GNNLux/test/layers/basic_tests.jl
@@ -0,0 +1,24 @@
@testitem "layers/basic" setup=[SharedTestSetup] begin
rng = StableRNG(17)
g = rand_graph(10, 40, seed=17)
x = randn(rng, Float32, 3, 10)

@testset "GNNLayer" begin
@test GNNLayer <: LuxCore.AbstractExplicitLayer
end

@testset "GNNChain" begin
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
@test GNNChain <: GNNContainerLayer
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
ps = LuxCore.initialparameters(rng, c)
st = LuxCore.initialstates(rng, c)
@test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(c) == LuxCore.statelength(st)
y, st′ = c(g, x, ps, st)
@test LuxCore.outputsize(c) == (3,)
@test size(y) == (3, 10)
loss = (x, ps) -> sum(first(c(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end
end
35 changes: 33 additions & 2 deletions GNNLux/test/layers/conv_tests.jl
@@ -1,10 +1,41 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
rng = StableRNG(1234)
g = rand_graph(10, 30, seed=1234)
g = rand_graph(10, 40, seed=1234)
x = randn(rng, Float32, 3, 10)

@testset "GCNConv" begin
l = GCNConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@@ -14,6 +45,6 @@
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end
end