add buildkite workflow (#526)
CarloLucibello authored Nov 27, 2024
1 parent 301602e commit d8d69ec
Showing 11 changed files with 119 additions and 54 deletions.
25 changes: 25 additions & 0 deletions .buildkite/pipeline.yml
@@ -0,0 +1,25 @@
steps:
  - label: "GNN CUDA"
    plugins:
      - JuliaCI/julia#v1:
          version: "1"
      - JuliaCI/julia-coverage#v1:
          dirs:
            - GraphNeuralNetworks/src
    command: |
      julia --color=yes --depwarn=yes --project=GraphNeuralNetworks/test -e '
        import Pkg
        dev_pkgs = Pkg.PackageSpec[]
        for pkg in ("GNNGraphs", "GNNlib", "GraphNeuralNetworks")
          push!(dev_pkgs, Pkg.PackageSpec(path=pkg));
        end
        Pkg.develop(dev_pkgs)
        Pkg.add(["CUDA", "cuDNN"])
        Pkg.test("GraphNeuralNetworks")'
    agents:
      queue: "juliagpu"
      cuda: "*"
    env:
      GNN_TEST_CUDA: "true"
      GNN_TEST_CPU: "false"
    timeout_in_minutes: 60
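The step above can also be reproduced outside of CI. A rough local equivalent, run from the repository root (the package names and environment variables are taken from the pipeline command; everything else is a sketch, not part of this commit):

```julia
import Pkg

# Activate the shared test environment used by the pipeline's `--project` flag.
Pkg.activate("GraphNeuralNetworks/test")

# `dev` the local monorepo packages, mirroring the loop in the pipeline command.
Pkg.develop([Pkg.PackageSpec(path = p) for p in ("GNNGraphs", "GNNlib", "GraphNeuralNetworks")])

# The CUDA stack is added on demand, as in the CI step.
Pkg.add(["CUDA", "cuDNN"])

# Select the CUDA test group, then run the suite.
ENV["GNN_TEST_CUDA"] = "true"
ENV["GNN_TEST_CPU"] = "false"
Pkg.test("GraphNeuralNetworks")
```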
6 changes: 6 additions & 0 deletions GNNGraphs/test/test_utils.jl
@@ -1,6 +1,12 @@
using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt, CUDA
CUDA.allowscalar(false)

# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed
function FiniteDifferences.to_vec(x::Integer)
Integer_from_vec(v) = x
return Int[x], Integer_from_vec
end
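A quick illustration of what this workaround does (hypothetical usage, not part of the diff): integers round-trip through `to_vec` unchanged, so finite differencing treats them as constants instead of erroring.

```julia
# With the override above:
v, from_vec = FiniteDifferences.to_vec(5)
# v == [5]              -- the integer is represented as a one-element Int vector
# from_vec([5.1]) == 5  -- perturbations are discarded; the original Integer is returned
```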

function ngradient(f, x...)
fdm = central_fdm(5, 1)
return FiniteDifferences.grad(fdm, f, x...)
27 changes: 0 additions & 27 deletions GraphNeuralNetworks/Project.toml
@@ -6,10 +6,8 @@ version = "0.6.22"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -18,41 +16,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[compat]
CUDA = "4, 5"
ChainRulesCore = "1"
Flux = "0.14"
Functors = "0.4.1"
GNNGraphs = "1.0"
GNNlib = "0.2"
Graphs = "1.12"
LinearAlgebra = "1"
MLUtils = "0.4"
MacroTools = "0.5"
NNlib = "0.9"
Pkg = "1"
Random = "1"
Reexport = "1"
Statistics = "1"
TestItemRunner = "1.0.5"
julia = "1.10"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "TestItemRunner", "Pkg", "MLDatasets", "Adapt", "DataFrames", "InlineStrings", "SparseArrays", "Graphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils"]
44 changes: 38 additions & 6 deletions GraphNeuralNetworks/docs/src/dev.md
@@ -1,30 +1,62 @@
# Developer Notes

## Developing and Managing the Monorepo

### Development Environment
## Development Environment
GraphNeuralNetworks.jl is a package hosted in a monorepo that contains multiple packages.
The GraphNeuralNetworks.jl package depends on GNNGraphs.jl, also hosted in the same monorepo.
The GraphNeuralNetworks.jl package depends on GNNGraphs.jl and GNNlib.jl, also hosted in the same monorepo.
In order to develop the packages locally, activate the monorepo environment and `dev` the local copies of the dependencies:

```julia
pkg> activate .

pkg> dev ./GNNGraphs
```
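GraphNeuralNetworks.jl also depends on GNNlib.jl, so when developing across packages you will typically `dev` both local copies. A minimal sketch, assuming the commands are issued from the repository root:

```julia
pkg> activate .

pkg> dev ./GNNGraphs ./GNNlib
```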

### Add a New Layer
## Add a New Layer

To add a new graph convolutional layer and make it available in both the Flux-based frontend (GraphNeuralNetworks.jl) and the Lux-based frontend (GNNLux), you need to:

1. Add the functional version to GNNlib
2. Add the stateful version to GraphNeuralNetworks
3. Add the stateless version to GNNLux
4. Add the layer to the table in docs/api/conv.md

### Versions and Tagging
We suggest starting by implementing a self-contained Flux layer in GraphNeuralNetworks.jl, adding the corresponding tests, and then, once everything is working, moving the implementation of the forward pass to GNNlib.jl. At this point, you can add the stateless version to GNNLux.jl.

It can also be convenient to use the `@structdef` macro from [AutoStructs.jl](https://github.com/CarloLucibello/AutoStructs.jl) to generate the struct and the constructor for the layer in one step.
For example, the Flux implementation of the [`MEGNetConv`](@ref) layer can be written as follows:

```julia
using Flux, GraphNeuralNetworks, AutoStructs

@structdef function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
    nin, nout = ch
    ϕe = Chain(Dense(3nin, nout, relu),
               Dense(nout, nout))

    ϕv = Chain(Dense(nin + nout, nout, relu),
               Dense(nout, nout))

    return MEGNetConv(ϕe, ϕv, aggr)
end

Flux.@layer MEGNetConv

function (l::MEGNetConv)(g::AbstractGraph, x::AbstractMatrix, e::AbstractMatrix)
    ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e
        l.ϕe(vcat(xi, xj, e))
    end
    xᵉ = aggregate_neighbors(g, l.aggr, ē)
    x̄ = l.ϕv(vcat(x, xᵉ))
    return x̄, ē
end
```
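Once the Flux layer works, step 1 above amounts to moving the message-passing body into GNNlib.jl as a plain function that the layer then calls. A rough sketch of that split (the function name `megnet_conv` and the exact signatures are assumptions, not fixed by this commit):

```julia
# In GNNlib.jl (sketch): purely functional forward pass, no trainable state of its own.
function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    ē = apply_edges((xi, xj, e) -> l.ϕe(vcat(xi, xj, e)), g, xi = x, xj = x, e = e)
    xᵉ = aggregate_neighbors(g, l.aggr, ē)
    x̄ = l.ϕv(vcat(x, xᵉ))
    return x̄, ē
end

# In GraphNeuralNetworks.jl (sketch): the Flux layer simply forwards to the functional version.
(l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) = megnet_conv(l, g, x, e)
```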

## Versions and Tagging
Each PR should update the version number in the Project.toml file of each involved package, if required by semantic versioning. For instance, when adding new features, GNNGraphs could move from "1.17.5" to "1.18.0-DEV". The "DEV" suffix is removed when the package is tagged and released. Also pay attention to updating
the compat bounds, e.g. GraphNeuralNetworks might require a newer version of GNNGraphs.

### Generate Documentation Locally
## Generate Documentation Locally
To generate the documentation locally, run:
```
cd docs
```
4 changes: 1 addition & 3 deletions GraphNeuralNetworks/src/GraphNeuralNetworks.jl
@@ -6,11 +6,9 @@ using Flux
using Flux: glorot_uniform, leakyrelu, GRUCell, batch
using MacroTools: @forward
using NNlib
using NNlib: scatter, gather
using ChainRulesCore
using Reexport
using Reexport: @reexport
using MLUtils: zeros_like
using Graphs: Graphs

using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
15 changes: 15 additions & 0 deletions GraphNeuralNetworks/test/Project.toml
@@ -0,0 +1,15 @@
[deps]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -108,7 +108,8 @@ end # module
TrainingExampleModule.train_many()
end

@testitem "training example GPU" setup=[TrainingExampleModule] tags=[:gpu] begin
@testitem "training example GPU" setup=[TestModule, TrainingExampleModule] tags=[:gpu] begin
using .TestModule # for loading gpu packages
using .TrainingExampleModule
TrainingExampleModule.train_many(use_gpu = true)
end
5 changes: 5 additions & 0 deletions GraphNeuralNetworks/test/layers/conv.jl
@@ -101,6 +101,7 @@ end
k = 2
l = ChebConv(D_IN => D_OUT, k)
for g in TEST_GRAPHS
has_isolated_nodes(g) && continue
g.graph isa AbstractSparseMatrix && continue
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_gpu = true, compare_finite_diff = false)
@@ -377,6 +378,7 @@ end
l = CGConv((D_IN, edim) => D_OUT, tanh, residual = false, bias = true)
for g in TEST_GRAPHS
g.graph isa AbstractSparseMatrix && continue
g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges))
@test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, g.e, rtol = RTOL_HIGH, test_gpu = true, compare_finite_diff = false)
end
@@ -432,6 +434,7 @@ end
l = MEGNetConv(D_IN => D_OUT, aggr = +)
for g in TEST_GRAPHS
g.graph isa AbstractSparseMatrix && continue
g = GNNGraph(g, edata = rand(Float32, D_IN, g.num_edges))
y = l(g, g.x, g.e)
@test size(y[1]) == (D_OUT, g.num_nodes)
@test size(y[2]) == (D_OUT, g.num_edges)
@@ -462,6 +465,7 @@ end
l = GMMConv((D_IN, ein_channel) => D_OUT, K = K)
for g in TEST_GRAPHS
g.graph isa AbstractSparseMatrix && continue
g = GNNGraph(g, edata = rand(Float32, ein_channel, g.num_edges))
y = l(g, g.x, g.e)
test_gradients(l, g, g.x, g.e, rtol = RTOL_HIGH, test_gpu = true, compare_finite_diff = false)
end
@@ -585,6 +589,7 @@ end
bias_qkv = true)
for g in TEST_GRAPHS
g.graph isa AbstractSparseMatrix && continue
g = GNNGraph(g, edata = rand(Float32, ein, g.num_edges))
@test size(l(g, g.x, g.e)) == (D_IN * heads, g.num_nodes)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, test_gpu = true, compare_finite_diff = false)
end
6 changes: 5 additions & 1 deletion GraphNeuralNetworks/test/runtests.jl
@@ -4,12 +4,16 @@ using TestItemRunner
## for how to run the tests within VS Code.
## See test_module.jl for the test infrastructure.

## Uncomment below to change the default test settings
## Uncomment below and in test_module.jl to change the default test settings
# ENV["GNN_TEST_CPU"] = "false"
# ENV["GNN_TEST_CUDA"] = "true"
# ENV["GNN_TEST_AMDGPU"] = "true"
# ENV["GNN_TEST_Metal"] = "true"

# The only available tag at the moment is :gpu
# Tests not tagged with :gpu are considered to be CPU tests
# Tests tagged with :gpu should run on all GPU backends

if get(ENV, "GNN_TEST_CPU", "true") == "true"
@run_package_tests filter = ti -> :gpu ∉ ti.tags
end
15 changes: 9 additions & 6 deletions GraphNeuralNetworks/test/test_module.jl
@@ -3,9 +3,11 @@
using GraphNeuralNetworks
using Test
using Statistics, Random
using Flux, Functors
using Flux
using Functors: fmapstructure_with_path
using Graphs
using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt
using ChainRulesTestUtils, FiniteDifferences
using Zygote
using SparseArrays
using Pkg

@@ -16,22 +18,23 @@ using Pkg
# ENV["GNN_TEST_Metal"] = "true"

if get(ENV, "GNN_TEST_CUDA", "false") == "true"
Pkg.add(["CUDA", "cuDNN"])
# Pkg.add(["CUDA", "cuDNN"])
using CUDA
CUDA.allowscalar(false)
end
if get(ENV, "GNN_TEST_AMDGPU", "false") == "true"
Pkg.add("AMDGPU")
# Pkg.add("AMDGPU")
using AMDGPU
AMDGPU.allowscalar(false)
end
if get(ENV, "GNN_TEST_Metal", "false") == "true"
Pkg.add("Metal")
# Pkg.add("Metal")
using Metal
Metal.allowscalar(false)
end

# from Base
export mean, randn, SparseArrays, AbstractSparseMatrix

# from other packages
23 changes: 13 additions & 10 deletions README.md
@@ -7,27 +7,30 @@
![](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg)
[![codecov](https://codecov.io/gh/JuliaGraphs/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaGraphs/GraphNeuralNetworks.jl)

This is the monorepository for the GraphNeuralNetworks project, bringing together all code into a unified structure to facilitate code sharing and reusability across different project components. It contains the following packages:
Libraries for deep learning on graphs in Julia, using either [Flux.jl](https://fluxml.ai/Flux.jl/stable/) or [Lux.jl](https://lux.csail.mit.edu/stable/) as the backend framework.

- `GraphNeuralNetworks.jl`: Package that contains stateful graph convolutional layers based on the machine learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/). This is the frontend package for Flux users. It depends on the GNNlib.jl, GNNGraphs.jl, and Flux.jl packages.
This monorepo contains the following packages:

- `GNNLux.jl`: Package that contains stateless graph convolutional layers based on the machine learning framework [Lux.jl](https://lux.csail.mit.edu/stable/). This is the frontend package for Lux users. It depends on the GNNlib.jl, GNNGraphs.jl, and Lux.jl packages.
- `GraphNeuralNetworks.jl`: Graph convolutional layers based on the deep learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/). This is the frontend package for Flux users.

- `GNNlib.jl`: Package that contains the core graph neural network layers and utilities. It depends on the GNNGraphs.jl package and serves as the code base for the GraphNeuralNetworks.jl and GNNLux.jl packages.
- `GNNLux.jl`: Graph convolutional layers based on the deep learning framework [Lux.jl](https://lux.csail.mit.edu/stable/). This is the frontend package for Lux users. This package is still under development and is not yet registered.

- `GNNGraphs.jl`: Package that contains the graph data structures and helper functions for working with graph data. It depends on Graphs.jl package.
- `GNNlib.jl`: Contains the message passing framework based on the gather/scatter mechanism or on sparse matrix multiplication. It also contains the shared implementation of the layers of the two frontend packages. This package is not meant to be used directly by the user, but its functionalities are used and re-exported by the frontend packages.

- `GNNGraphs.jl`: Package that contains the graph data structures and helper functions for working with graph data. It depends on Graphs.jl package.


Among its general features:
Both `GraphNeuralNetworks.jl` and `GNNLux.jl` offer the following features:

* Implements common graph convolutional layers both in stateful and stateless form.
* Supports computations on batched graphs.
* Implement common graph convolutional layers.
* Support computations on batched graphs.
* Easy to define custom layers.
* CUDA support.
* CUDA and AMDGPU support.
* Integration with [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl).
* [Examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/GraphNeuralNetworks/examples) of node, edge, and graph level machine learning tasks.
* Heterogeneous and temporal graphs.
* Support for heterogeneous and temporal graphs.

## Installation

