From 55fe50b8d2a4f5af3036874601ed951abedd4a6e Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 30 Nov 2024 11:55:32 +0100 Subject: [PATCH] get_graph_type --- GNNGraphs/src/GNNGraphs.jl | 1 + GNNGraphs/src/operators.jl | 4 +- GNNGraphs/src/query.jl | 58 +++++++++++++++++++++++-- GNNGraphs/test/query.jl | 17 ++++++++ GraphNeuralNetworks/test/test_module.jl | 7 ++- 5 files changed, 81 insertions(+), 6 deletions(-) diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index 5a5b5fe66..3054a9ab8 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -47,6 +47,7 @@ include("query.jl") export adjacency_list, edge_index, get_edge_weight, + get_graph_type, graph_indicator, has_multi_edges, is_directed, diff --git a/GNNGraphs/src/operators.jl b/GNNGraphs/src/operators.jl index 4fdd6ac87..1faa4adcb 100644 --- a/GNNGraphs/src/operators.jl +++ b/GNNGraphs/src/operators.jl @@ -6,8 +6,8 @@ Intersect two graphs by keeping only the common edges. """ function Base.intersect(g1::GNNGraph, g2::GNNGraph) @assert g1.num_nodes == g2.num_nodes - @assert graph_type_symbol(g1) == graph_type_symbol(g2) - graph_type = graph_type_symbol(g1) + @assert get_graph_type(g1) == get_graph_type(g2) + graph_type = get_graph_type(g1) num_nodes = g1.num_nodes idx1, _ = edge_encoding(edge_index(g1)..., num_nodes) diff --git a/GNNGraphs/src/query.jl b/GNNGraphs/src/query.jl index 18622af21..a11eb564c 100644 --- a/GNNGraphs/src/query.jl +++ b/GNNGraphs/src/query.jl @@ -80,9 +80,61 @@ function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Intege return any((s .== i) .& (t .== j)) end -graph_type_symbol(::GNNGraph{<:COO_T}) = :coo -graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse -graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense +""" + get_graph_type(g::GNNGraph) + +Return the underlying representation for the graph `g` as a symbol. + +Possible values are: +- `:coo`: Coordinate list representation. The graph is stored as a tuple of vectors `(s, t, w)`, + where `s` and `t` are the source and target nodes of the edges, and `w` is the edge weights. +- `:sparse`: Sparse matrix representation. The graph is stored as a sparse matrix representing the weighted adjacency matrix. +- `:dense`: Dense matrix representation. The graph is stored as a dense matrix representing the weighted adjacency matrix. + +The default representation for graph constructors GNNGraphs.jl is `:coo`. +The underlying representation can be accessed through the `g.graph` field. + +See also [`GNNGraph`](@ref). + +# Examples + +The default representation for graph constructors GNNGraphs.jl is `:coo`. +```jldoctest +julia> g = rand_graph(5, 10) +GNNGraph: + num_nodes: 5 + num_edges: 10 + +julia> get_graph_type(g) +:coo +``` +The `GNNGraph` constructor can also be used to create graphs with different representations. +```jldoctest +julia> g = GNNGraph([2,3,5], [1,2,4], graph_type=:sparse) +GNNGraph: + num_nodes: 5 + num_edges: 3 + +julia> g.graph +5×5 SparseArrays.SparseMatrixCSC{Int64, Int64} with 3 stored entries: + ⋅ ⋅ ⋅ ⋅ ⋅ + 1 ⋅ ⋅ ⋅ ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + +julia> get_graph_type(g) +:sparse + +julia> gcoo = GNNGraph(g, graph_type=:coo); + +julia> gcoo.graph +([2, 3, 5], [1, 2, 4], [1, 1, 1]) +``` +""" +get_graph_type(::GNNGraph{<:COO_T}) = :coo +get_graph_type(::GNNGraph{<:SPARSE_T}) = :sparse +get_graph_type(::GNNGraph{<:ADJMAT_T}) = :dense Graphs.nv(g::GNNGraph) = g.num_nodes Graphs.ne(g::GNNGraph) = g.num_edges diff --git a/GNNGraphs/test/query.jl b/GNNGraphs/test/query.jl index e7f55e76a..a345ae779 100644 --- a/GNNGraphs/test/query.jl +++ b/GNNGraphs/test/query.jl @@ -257,3 +257,20 @@ if GRAPH_T == :coo end end +@testset "get_graph_type" begin + g = rand_graph(10, 20, graph_type = GRAPH_T) + @test get_graph_type(g) == GRAPH_T + + gsparse = GNNGraph(g, graph_type=:sparse) + @test get_graph_type(gsparse) == :sparse + @test gsparse.graph isa SparseMatrixCSC + + gcoo = GNNGraph(g, graph_type=:coo) + @test get_graph_type(gcoo) == :coo + @test gcoo.graph[1:2] isa Tuple{Vector{Int}, Vector{Int}} + + + gdense = GNNGraph(g, graph_type=:dense) + @test get_graph_type(gdense) == :dense + @test gdense.graph isa Matrix{Int} +end diff --git a/GraphNeuralNetworks/test/test_module.jl b/GraphNeuralNetworks/test/test_module.jl index bc515880a..87e116a1e 100644 --- a/GraphNeuralNetworks/test/test_module.jl +++ b/GraphNeuralNetworks/test/test_module.jl @@ -59,15 +59,20 @@ function finitediff_withgradient(f, x...) end function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4) + equal = true fmapstructure_with_path(a, b) do kp, x, y if x isa AbstractArray # @show kp - @assert x ≈ y rtol=rtol atol=atol + # @assert x ≈ y rtol=rtol atol=atol + if !isapprox(x, y; rtol, atol) + equal = false + end # elseif x isa Number # @show kp # @assert x ≈ y rtol=rtol atol=atol end end + @assert equal end function test_gradients(