-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* create GNNLux * create GNNLux.jl * fix ci
- Loading branch information
1 parent
cafc1bc
commit 79515e9
Showing
10 changed files
with
272 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
name: GNNLux | ||
on: | ||
pull_request: | ||
branches: | ||
- master | ||
push: | ||
branches: | ||
- master | ||
jobs: | ||
test: | ||
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
version: | ||
- '1.10' # Replace this with the minimum Julia version that your package supports. | ||
# - '1' # '1' will automatically expand to the latest stable 1.x release of Julia. | ||
# - 'pre' | ||
os: | ||
- ubuntu-latest | ||
arch: | ||
- x64 | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: julia-actions/setup-julia@v2 | ||
with: | ||
version: ${{ matrix.version }} | ||
arch: ${{ matrix.arch }} | ||
- uses: julia-actions/cache@v2 | ||
- uses: julia-actions/julia-buildpkg@v1 | ||
- name: Install Julia dependencies and run tests | ||
shell: julia --project=monorepo {0} | ||
run: | | ||
using Pkg | ||
# dev mono repo versions | ||
pkg"registry up" | ||
Pkg.update() | ||
pkg"dev ./GNNGraphs ./GNNlib ./GNNLux" | ||
Pkg.test("GNNLux"; coverage=true) | ||
- uses: julia-actions/julia-processcoverage@v1 | ||
with: | ||
# directories: ./GNNLux/src, ./GNNLux/ext | ||
directories: ./GNNLux/src | ||
- uses: codecov/codecov-action@v4 | ||
with: | ||
files: lcov.info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2024 Carlo Lucibello <[email protected]> and contributors | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
name = "GNNLux" | ||
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d" | ||
authors = ["Carlo Lucibello and contributors"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" | ||
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" | ||
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Reexport = "189a3867-3050-52da-a836-e630ba90ab69" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
||
[compat] | ||
ConcreteStructs = "0.2.3" | ||
Lux = "0.5.61" | ||
LuxCore = "0.1.20" | ||
NNlib = "0.9.21" | ||
Reexport = "1.2" | ||
julia = "1.10" | ||
|
||
[extras] | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" | ||
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" | ||
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[targets] | ||
test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# GNNLux.jl | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
module GNNLux | ||
using ConcreteStructs: @concrete | ||
using NNlib: NNlib | ||
using LuxCore: LuxCore, AbstractExplicitLayer | ||
using Lux: glorot_uniform, zeros32 | ||
using Reexport: @reexport | ||
using Random: AbstractRNG | ||
using GNNlib: GNNlib | ||
@reexport using GNNGraphs | ||
|
||
include("layers/conv.jl") | ||
export GraphConv | ||
|
||
end #module | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
|
||
@doc raw""" | ||
GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform) | ||
Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244). | ||
Performs: | ||
```math | ||
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j | ||
``` | ||
where the aggregation type is selected by `aggr`. | ||
# Arguments | ||
- `in`: The dimension of input features. | ||
- `out`: The dimension of output features. | ||
- `σ`: Activation function. | ||
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). | ||
- `bias`: Add learnable bias. | ||
- `init`: Weights' initializer. | ||
# Examples | ||
```julia | ||
# create data | ||
s = [1,1,2,3] | ||
t = [2,3,1,1] | ||
in_channel = 3 | ||
out_channel = 5 | ||
g = GNNGraph(s, t) | ||
x = randn(Float32, 3, g.num_nodes) | ||
# create layer | ||
l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean) | ||
# forward pass | ||
y = l(g, x) | ||
``` | ||
""" | ||
@concrete struct GraphConv <: AbstractExplicitLayer | ||
in_dims::Int | ||
out_dims::Int | ||
use_bias::Bool | ||
init_weight::Function | ||
init_bias::Function | ||
σ | ||
aggr | ||
end | ||
|
||
|
||
function GraphConv(ch::Pair{Int, Int}, σ = identity; | ||
aggr = +, | ||
init_weight = glorot_uniform, | ||
init_bias = zeros32, | ||
use_bias::Bool = true, | ||
allow_fast_activation::Bool = true) | ||
in_dims, out_dims = ch | ||
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ | ||
return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr) | ||
end | ||
|
||
function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv) | ||
weight1 = l.init_weight(rng, l.out_dims, l.in_dims) | ||
weight2 = l.init_weight(rng, l.out_dims, l.in_dims) | ||
if l.use_bias | ||
bias = l.init_bias(rng, l.out_dims) | ||
else | ||
bias = false | ||
end | ||
return (; weight1, weight2, bias) | ||
end | ||
|
||
function LuxCore.parameterlength(l::GraphConv) | ||
if l.use_bias | ||
return 2 * l.in_dims * l.out_dims + l.out_dims | ||
else | ||
return 2 * l.in_dims * l.out_dims | ||
end | ||
end | ||
|
||
LuxCore.statelength(d::GraphConv) = 0 | ||
LuxCore.outputsize(d::GraphConv) = (d.out_dims,) | ||
|
||
function Base.show(io::IO, l::GraphConv) | ||
print(io, "GraphConv(", l.in_dims, " => ", l.out_dims) | ||
(l.σ == identity) || print(io, ", ", l.σ) | ||
(l.aggr == +) || print(io, ", aggr=", l.aggr) | ||
l.use_bias || print(io, ", use_bias=false") | ||
print(io, ")") | ||
end | ||
|
||
(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
@testitem "layers/conv" setup=[SharedTestSetup] begin | ||
rng = StableRNG(1234) | ||
g = rand_graph(10, 30, seed=1234) | ||
x = randn(rng, Float32, 3, 10) | ||
|
||
@testset "GraphConv" begin | ||
l = GraphConv(3 => 5, relu) | ||
ps = Lux.initialparameters(rng, l) | ||
st = Lux.initialstates(rng, l) | ||
@test Lux.parameterlength(l) == Lux.parameterlength(ps) | ||
@test Lux.statelength(l) == Lux.statelength(st) | ||
|
||
y, _ = l(g, x, ps, st) | ||
@test Lux.outputsize(l) == (5,) | ||
@test size(y) == (5, 10) | ||
loss = (x, ps) -> sum(first(l(g, x, ps, st))) | ||
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using Test | ||
using Lux | ||
using GNNLux | ||
using Random, Statistics | ||
|
||
using ReTestItems | ||
# using Pkg, Preferences, Test | ||
# using InteractiveUtils, Hwloc | ||
|
||
runtests(GNNLux) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
@testsetup module SharedTestSetup | ||
|
||
import Reexport: @reexport | ||
|
||
@reexport using Lux, Functors | ||
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test, | ||
Zygote, Statistics | ||
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx | ||
|
||
# Some Helper Functions | ||
function get_default_rng(mode::String) | ||
dev = mode == "cpu" ? LuxCPUDevice() : | ||
mode == "cuda" ? LuxCUDADevice() : mode == "amdgpu" ? LuxAMDGPUDevice() : nothing | ||
rng = default_device_rng(dev) | ||
return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng) | ||
end | ||
|
||
export get_default_rng | ||
|
||
# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng, | ||
# StableRNG, maybe_rewrite_to_crosscor | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters