Skip to content

Commit

Permalink
get_graph_type
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 30, 2024
1 parent ebab567 commit 55fe50b
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 6 deletions.
1 change: 1 addition & 0 deletions GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include("query.jl")
export adjacency_list,
edge_index,
get_edge_weight,
get_graph_type,
graph_indicator,
has_multi_edges,
is_directed,
Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ Intersect two graphs by keeping only the common edges.
"""
function Base.intersect(g1::GNNGraph, g2::GNNGraph)
@assert g1.num_nodes == g2.num_nodes
@assert graph_type_symbol(g1) == graph_type_symbol(g2)
graph_type = graph_type_symbol(g1)
@assert get_graph_type(g1) == get_graph_type(g2)
graph_type = get_graph_type(g1)
num_nodes = g1.num_nodes

idx1, _ = edge_encoding(edge_index(g1)..., num_nodes)
Expand Down
58 changes: 55 additions & 3 deletions GNNGraphs/src/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,61 @@ function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Intege
return any((s .== i) .& (t .== j))
end

graph_type_symbol(::GNNGraph{<:COO_T}) = :coo
graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse
graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense
"""
get_graph_type(g::GNNGraph)
Return the underlying representation for the graph `g` as a symbol.
Possible values are:
- `:coo`: Coordinate list representation. The graph is stored as a tuple of vectors `(s, t, w)`,
where `s` and `t` are the source and target nodes of the edges, and `w` is the edge weights.
- `:sparse`: Sparse matrix representation. The graph is stored as a sparse matrix representing the weighted adjacency matrix.
- `:dense`: Dense matrix representation. The graph is stored as a dense matrix representing the weighted adjacency matrix.
The default representation for graph constructors GNNGraphs.jl is `:coo`.
The underlying representation can be accessed through the `g.graph` field.
See also [`GNNGraph`](@ref).
# Examples
The default representation for graph constructors GNNGraphs.jl is `:coo`.
```jldoctest
julia> g = rand_graph(5, 10)
GNNGraph:
num_nodes: 5
num_edges: 10
julia> get_graph_type(g)
:coo
```
The `GNNGraph` constructor can also be used to create graphs with different representations.
```jldoctest
julia> g = GNNGraph([2,3,5], [1,2,4], graph_type=:sparse)
GNNGraph:
num_nodes: 5
num_edges: 3
julia> g.graph
5×5 SparseArrays.SparseMatrixCSC{Int64, Int64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
1 ⋅ ⋅ ⋅ ⋅
⋅ 1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 1 ⋅
julia> get_graph_type(g)
:sparse
julia> gcoo = GNNGraph(g, graph_type=:coo);
julia> gcoo.graph
([2, 3, 5], [1, 2, 4], [1, 1, 1])
```
"""
get_graph_type(::GNNGraph{<:COO_T}) = :coo
get_graph_type(::GNNGraph{<:SPARSE_T}) = :sparse
get_graph_type(::GNNGraph{<:ADJMAT_T}) = :dense

Graphs.nv(g::GNNGraph) = g.num_nodes
Graphs.ne(g::GNNGraph) = g.num_edges
Expand Down
17 changes: 17 additions & 0 deletions GNNGraphs/test/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,20 @@ if GRAPH_T == :coo
end
end

@testset "get_graph_type" begin
g = rand_graph(10, 20, graph_type = GRAPH_T)
@test get_graph_type(g) == GRAPH_T

gsparse = GNNGraph(g, graph_type=:sparse)
@test get_graph_type(gsparse) == :sparse
@test gsparse.graph isa SparseMatrixCSC

gcoo = GNNGraph(g, graph_type=:coo)
@test get_graph_type(gcoo) == :coo
@test gcoo.graph[1:2] isa Tuple{Vector{Int}, Vector{Int}}


gdense = GNNGraph(g, graph_type=:dense)
@test get_graph_type(gdense) == :dense
@test gdense.graph isa Matrix{Int}
end
7 changes: 6 additions & 1 deletion GraphNeuralNetworks/test/test_module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,20 @@ function finitediff_withgradient(f, x...)
end

function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
equal = true
fmapstructure_with_path(a, b) do kp, x, y
if x isa AbstractArray
# @show kp
@assert x y rtol=rtol atol=atol
# @assert x ≈ y rtol=rtol atol=atol
if !isapprox(x, y; rtol, atol)
equal = false
end
# elseif x isa Number
# @show kp
# @assert x ≈ y rtol=rtol atol=atol
end
end
@assert equal
end

function test_gradients(
Expand Down

0 comments on commit 55fe50b

Please sign in to comment.