use GNNlib in GNN.jl
CarloLucibello committed Jul 28, 2024
1 parent 80c672a commit b0e5dd9
Showing 7 changed files with 49 additions and 435 deletions.
@@ -1,31 +1,37 @@
 module GNNlibCUDAExt

 using CUDA
 using Random, Statistics, LinearAlgebra
 using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
 using GNNGraphs: GNNGraph, COO_T, SPARSE_T

 ###### PROPAGATE SPECIALIZATIONS ####################

 ## COPY_XJ

 ## avoid the fast path on gpu until we have better cuda support
-function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
-                   xi, xj::AnyCuMatrix, e)
+function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                          xi, xj::AnyCuMatrix, e)
     propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
 end

 ## E_MUL_XJ

 ## avoid the fast path on gpu until we have better cuda support
-function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
-                   xi, xj::AnyCuMatrix, e::AbstractVector)
+function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                          xi, xj::AnyCuMatrix, e::AbstractVector)
     propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
 end

 ## W_MUL_XJ

 ## avoid the fast path on gpu until we have better cuda support
-function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
-                   xi, xj::AnyCuMatrix, e::Nothing)
+function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                          xi, xj::AnyCuMatrix, e::Nothing)
     propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
 end

-# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
+# function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
 #     A = adjacency_matrix(g, weighted=false)
 #     D = compute_degree(A)
 #     return xj * A * D
@@ -35,3 +41,5 @@ end
 # compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))

 # Flux.Zygote.@nograd compute_degree
+
+end #module
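The substantive change in this file is the module qualification: inside an extension, `function propagate(...)` would define a new function local to GNNlibCUDAExt, shadowing rather than extending the one from GNNlib, while `function GNNlib.propagate(...)` adds methods to GNNlib's own function so callers dispatch into these CUDA specializations. (The anonymous closure in each body then defeats the `::typeof(copy_xj)`-style dispatch, forcing the generic gather/scatter path, per the "avoid the fast path on gpu" comments.) A minimal stand-alone sketch of the qualification rule, with toy module names that are not from this commit:

module Upstream
f(x) = "generic"
end # Upstream

module WrongExt
# This `f` is a brand-new local function; Upstream.f never sees this method.
f(x::String) = "string"
end # WrongExt

module RightExt
import ..Upstream
# Qualifying with the owning module adds a method to Upstream.f itself.
Upstream.f(x::String) = "string"
end # RightExt

Upstream.f(1)      # "generic"
Upstream.f("hi")   # "string" -- dispatch now reaches the extension's method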
11 changes: 0 additions & 11 deletions GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl

This file was deleted.

10 changes: 6 additions & 4 deletions Project.toml
@@ -9,9 +9,10 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
+GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -20,8 +21,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [weakdeps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

-[extensions]
-GraphNeuralNetworksCUDAExt = "CUDA"
+# [extensions]
+# GraphNeuralNetworksCUDAExt = "CUDA"

 [compat]
 CUDA = "4, 5"
@@ -30,9 +31,10 @@ DataStructures = "0.18"
 Flux = "0.14"
 Functors = "0.4.1"
 GNNGraphs = "1.0"
+GNNlib = "0.2"
 LinearAlgebra = "1"
-MacroTools = "0.5"
 MLUtils = "0.4"
+MacroTools = "0.5"
 NNlib = "0.9"
 Random = "1"
 Reexport = "1"
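The [extensions] stanza naming GraphNeuralNetworksCUDAExt is commented out because the CUDA-specific propagate methods now live in GNNlib's own extension instead. For reference, a hedged sketch of the stanza one would expect in GNNlib/Project.toml (assumed; it is not shown in this diff), which makes Julia >= 1.9 load GNNlibCUDAExt automatically whenever both GNNlib and CUDA are in the session:

# Assumed contents of GNNlib/Project.toml, not part of this commit.
[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
GNNlibCUDAExt = "CUDA"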
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -11,6 +11,7 @@ using ChainRulesCore
 using Reexport
 using DataStructures: nlargest
 using MLUtils: zeros_like
+using GNNlib: GNNlib

 @reexport using GNNGraphs
 using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
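`using GNNlib: GNNlib` binds only the module name, so GNNlib's functions are reachable solely through qualification (as in `GNNlib.propagate` in the extension above) and none of GNNlib's exports can clash with names defined in GraphNeuralNetworks.jl. A runnable stand-in with the Statistics standard library playing the role of GNNlib:

using Statistics: Statistics   # binds only the module name, not the exports

Statistics.mean([1, 2, 3])     # 2.0 -- qualified access works
# mean([1, 2, 3])              # UndefVarError: `mean` was never brought into scope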
4 changes: 3 additions & 1 deletion src/deprecations.jl
@@ -1,2 +1,4 @@
-
 @deprecate AGNNConv(init_beta) AGNNConv(; init_beta)
+# V1.0 deprecations
+# TODO for some reason this is not working
+# @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight)
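On the TODO above: `@deprecate old(args...) new(args...)` rewrites plain function signatures, and a functor method like `(l::GCNConv)(...)` is not a plain function name, which is presumably why the macro fails here. A manual workaround is to keep the old positional method and emit Base.depwarn from it; the sketch below uses a hypothetical ToyConv stand-in, not the package's actual code:

# Hypothetical stand-in for GCNConv; illustrative only.
struct ToyConv end

# New-style call: norm_fn and conv_weight passed as keywords.
(l::ToyConv)(g, x, edge_weight; norm_fn = identity, conv_weight = nothing) = norm_fn.(x)

# Old-style positional call kept alive with an explicit warning
# (visible under --depwarn=yes), forwarding to the keyword form.
function (l::ToyConv)(g, x, edge_weight, norm_fn; conv_weight = nothing)
    Base.depwarn("`l(g, x, edge_weight, norm_fn)` is deprecated, " *
                 "use `l(g, x, edge_weight; norm_fn)` instead.", :ToyConv)
    return l(g, x, edge_weight; norm_fn, conv_weight)
end

ToyConv()(nothing, [1.0, 4.0], nothing, sqrt)   # warns, returns [1.0, 2.0]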