Skip to content

Commit

Permalink
fix heterograph docs
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 2, 2024
1 parent a48126d commit 553af1d
Show file tree
Hide file tree
Showing 7 changed files with 5 additions and 481 deletions.
12 changes: 5 additions & 7 deletions GNNGraphs/docs/src/api/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ Private = false

```@autodocs
Modules = [GNNGraphs]
Pages = ["query.jl"]
Pages = ["src/query.jl"]
Private = false
```


```@docs
Graphs.neighbors(::GNNGraph, ::Integer)
```
Expand All @@ -44,7 +43,7 @@ Graphs.neighbors(::GNNGraph, ::Integer)

```@autodocs
Modules = [GNNGraphs]
Pages = ["transform.jl"]
Pages = ["src/transform.jl"]
Private = false
```

Expand All @@ -59,17 +58,16 @@ GNNGraphs.color_refinement

```@autodocs
Modules = [GNNGraphs]
Pages = ["generate.jl"]
Pages = ["src/generate.jl"]
Private = false
Filter = t -> typeof(t) <: Function && t!=rand_temporal_radius_graph && t!=rand_temporal_hyperbolic_graph
```

## Operators

```@autodocs
Modules = [GNNGraphs]
Pages = ["operators.jl"]
Pages = ["src/operators.jl"]
Private = false
```

Expand All @@ -81,7 +79,7 @@ Base.intersect

```@autodocs
Modules = [GNNGraphs]
Pages = ["sampling.jl"]
Pages = ["src/sampling.jl"]
Private = false
```

Expand Down
5 changes: 0 additions & 5 deletions GNNGraphs/docs/src/api/heterograph.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,12 @@ CollapsedDocStrings = true
## GNNHeteroGraph
Documentation page for the type `GNNHeteroGraph` representing heterogeneous graphs, where nodes and edges can have different types.


```@autodocs
Modules = [GNNGraphs]
Pages = ["gnnheterograph.jl"]
Private = false
```

```@docs
Graphs.has_edge(::GNNHeteroGraph, ::Tuple{Symbol, Symbol, Symbol}, ::Integer, ::Integer)
```

## Query

```@autodocs
Expand Down
3 changes: 0 additions & 3 deletions GNNGraphs/docs/src/api/samplers.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ CollapsedDocStrings = true

# Samplers


## Docs

```@autodocs
Modules = [GNNGraphs]
Pages = ["samplers.jl"]
Expand Down
124 changes: 0 additions & 124 deletions GNNGraphs/src/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,130 +64,6 @@ function rand_graph(rng::AbstractRNG, n::Integer, m::Integer;
return GNNGraph((s, t, edge_weight); num_nodes=n, kws...)
end

"""
rand_heterograph([rng,] n, m; bidirected=false, kws...)
Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges
specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
specifing node/edge types and their numbers.
Pass a random number generator as a first argument to make the generation reproducible.
Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
will be generated.
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
# Examples
```jldoctest
julia> g = rand_heterograph((:user => 10, :movie => 20),
(:user, :rate, :movie) => 30)
GNNHeteroGraph:
num_nodes: Dict(:movie => 20, :user => 10)
num_edges: Dict((:user, :rate, :movie) => 30)
```
"""
function rand_heterograph end

# for generic iterators of pairs
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...)

function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...)
if seed != -1
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph)
rng = MersenneTwister(seed)
else
rng = Random.default_rng()
end
return rand_heterograph(rng, n, m; kws...)
end

function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...)
if bidirected
return _rand_bidirected_heterograph(rng, n, m; kws...)
end
graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
end

function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...)
for k in keys(m)
if reverse(k) keys(m)
@assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
else
m[reverse(k)] = m[k]
end
end
graphs = Dict{EType, Tuple{Vector{Int}, Vector{Int}, Nothing}}()
for k in keys(m)
reverse(k) keys(graphs) && continue
s, t, val = _rand_edges(rng, (n[k[1]], n[k[3]]), m[k])
graphs[k] = s, t, val
graphs[reverse(k)] = t, s, val
end
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
end


"""
rand_bipartite_heterograph([rng,]
(n1, n2), (m12, m21);
bidirected = true,
node_t = (:A, :B),
edge_t = :to,
kws...)
Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph.
The graph will have two types of nodes, and edges will only connect nodes of different types.
The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type.
The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2`
and vice versa.
The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments,
which default to `(:A, :B)` and `:to` respectively.
If `bidirected=true` (default), the reverse edge of each edge will be present. In this case
`m12 == m21` is required.
A random number generator can be passed as the first argument to make the generation reproducible.
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
See [`rand_heterograph`](@ref) for a more general version.
# Examples
```julia
julia> g = rand_bipartite_heterograph((10, 15), 20)
GNNHeteroGraph:
num_nodes: (:A => 10, :B => 15)
num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20)
julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false)
GNNHeteroGraph:
num_nodes: Dict(:item => 15, :user => 10)
num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20)
```
"""
rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...)

function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true,
node_t = (:A, :B), edge_t::Symbol = :to, kws...)
if m isa Integer
m12 = m21 = m
else
m12, m21 = m
end

return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2),
Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21);
bidirected, kws...)
end

"""
knn_graph(points::AbstractMatrix,
k::Int;
Expand Down
89 changes: 0 additions & 89 deletions GNNGraphs/src/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,10 @@ edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2]

edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][1:2]

"""
edge_index(g::GNNHeteroGraph, [edge_t])
Return a tuple containing two vectors, respectively storing the source and target nodes
for each edges in `g` of type `edge_t = (src_t, rel_t, trg_t)`.
If `edge_t` is not provided, it will error if `g` has more than one edge type.
"""
edge_index(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][1:2]
edge_index(g::GNNHeteroGraph{<:COO_T}) = only(g.graph)[2][1:2]

get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]

get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][3]

get_edge_weight(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][3]

Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...)

Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)}
Expand All @@ -55,31 +42,6 @@ end

Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i, j] != 0

"""
has_edge(g::GNNHeteroGraph, edge_t, i, j)
Return `true` if there is an edge of type `edge_t` from node `i` to node `j` in `g`.
# Examples
```jldoctest
julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
GNNHeteroGraph:
num_nodes: Dict(:A => 2, :B => 2)
num_edges: Dict((:A, :to, :B) => 4, (:B, :to, :A) => 0)
julia> has_edge(g, (:A,:to,:B), 1, 1)
true
julia> has_edge(g, (:B,:to,:A), 1, 1)
false
```
"""
function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Integer)
s, t = edge_index(g, edge_t)
return any((s .== i) .& (t .== j))
end

"""
get_graph_type(g::GNNGraph)
Expand Down Expand Up @@ -390,36 +352,6 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
return _degree(A, T, dir, edge_weight, g.num_nodes)
end

"""
degree(g::GNNHeteroGraph, edge_type::EType; dir = :in)
Return a vector containing the degrees of the nodes in `g` GNNHeteroGraph
given `edge_type`.
# Arguments
- `g`: A graph.
- `edge_type`: A tuple of symbols `(source_t, edge_t, target_t)` representing the edge type.
- `T`: Element type of the returned vector. If `nothing`, is
chosen based on the graph type. Default `nothing`.
- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges.
For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two.
Default `dir = :out`.
"""
function Graphs.degree(g::GNNHeteroGraph, edge::EType,
T::TT = nothing; dir = :out) where {
TT <: Union{Nothing, Type{<:Number}}}

s, t = edge_index(g, edge)

T = isnothing(T) ? eltype(s) : T

n_type = dir == :in ? g.ntypes[2] : g.ntypes[1]

return _degree((s, t), T, dir, nothing, g.num_nodes[n_type])
end

function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::Nothing, num_nodes::Int)
_degree((s, t), T, dir, ones_like(s, T), num_nodes)
end
Expand Down Expand Up @@ -579,28 +511,7 @@ function graph_indicator(g::GNNGraph; edges = false)
end
end

"""
graph_indicator(g::GNNHeteroGraph, [node_t])
Return a Dict of vectors containing the graph membership
(an integer from `1` to `g.num_graphs`) of each node in the graph for each node type.
If `node_t` is provided, return the graph membership of each node of type `node_t` instead.

See also [`batch`](@ref).
"""
function graph_indicator(g::GNNHeteroGraph)
return g.graph_indicator
end

function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
@assert node_t g.ntypes
if isnothing(g.graph_indicator)
gi = ones_like(edge_index(g, first(g.etypes))[1], Int, g.num_nodes[node_t])
else
gi = g.graph_indicator[node_t]
end
return gi
end

function node_features(g::GNNGraph)
if isempty(g.ndata)
Expand Down
Loading

0 comments on commit 553af1d

Please sign in to comment.