Skip to content

Commit

Permalink
feat: add induced_subgraph functionality (#499)
Browse files Browse the repository at this point in the history
* feat: add induced_subgraph functionality

* fix: fix tests

* fix: fix tests

* Update GNNGraphs/src/sampling.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update GNNGraphs/src/GNNGraphs.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update GNNGraphs/src/sampling.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update GNNGraphs/src/sampling.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* feat: add edata&ndata support for induced_subgraph

* chore: export induced_subgraph

* fix: fix naming for induced_subgraph

* fix: fix typo

* fix: revert naming for induced_subgraph

* fix: fix test

* fix: don't export induced subgraph

* fix: don't export induced subgraph

* fix: amend docstring

* fix: amend docstring

* fix: fix docstring

* fix: add Graphs.induced_subgraph to docs

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
askorupka and CarloLucibello authored Oct 12, 2024
1 parent 3fe2c76 commit 8dbac17
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 2 deletions.
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using SparseArrays
using Functors: @functor
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
has_self_loops, is_directed
has_self_loops, is_directed, induced_subgraph
import NearestNeighbors
import NNlib
import StatsBase
Expand Down
84 changes: 84 additions & 0 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,87 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1;
end
return gnew
end


"""
induced_subgraph(graph, nodes)
Generates a subgraph from the original graph using the provided `nodes`.
The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph.
If a node has no neighbors, an isolated node will be added to the subgraph.
Returns A new `GNNGraph` containing the subgraph with the specified nodes and their features.
# Arguments
- `graph`. The original GNNGraph containing nodes, edges, and node features.
- `nodes``. A vector of node indices to include in the subgraph.
# Examples
```julia
julia> s = [1, 2]
2-element Vector{Int64}:
1
2
julia> t = [2, 3]
2-element Vector{Int64}:
2
3
julia> graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2))
GNNGraph:
num_nodes: 3
num_edges: 2
ndata:
y = 3-element Vector{Float32}
x = 32×3 Matrix{Float32}
edata:
e = 2-element Vector{Float32}
julia> nodes = [1, 2]
2-element Vector{Int64}:
1
2
julia> subgraph = Graphs.induced_subgraph(graph, nodes)
GNNGraph:
num_nodes: 2
num_edges: 1
ndata:
y = 2-element Vector{Float32}
x = 32×2 Matrix{Float32}
edata:
e = 1-element Vector{Float32}
```
"""
function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
if isempty(nodes)
return GNNGraph() # Return empty graph if no nodes are provided
end

node_map = Dict(node => i for (i, node) in enumerate(nodes))

# Collect edges to add
source = Int[]
target = Int[]
eindices = Int[]
for node in nodes
neighbors = Graphs.neighbors(graph, node, dir = :in)
for neighbor in neighbors
if neighbor in keys(node_map)
push!(target, node_map[node])
push!(source, node_map[neighbor])

eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
push!(eindices, eindex)
end
end
end

# Extract features for the new nodes
new_ndata = getobs(graph.ndata, nodes)
new_edata = getobs(graph.edata, eindices)

return GNNGraph(source, target, num_nodes = length(node_map), ndata = new_ndata, edata = new_edata)
end
32 changes: 32 additions & 0 deletions GNNGraphs/test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,36 @@ if GRAPH_T == :coo
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
end

@testset "induced_subgraph" begin
s = [1, 2]
t = [2, 3]

graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2))

nodes = [1, 2, 3]
subgraph = Graphs.induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 3
@test subgraph.num_edges == 2
@test subgraph.ndata.x == graph.ndata.x
@test subgraph.ndata.y == graph.ndata.y
@test subgraph.edata == graph.edata

nodes = [1, 2]
subgraph = Graphs.induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 2
@test subgraph.num_edges == 1
@test subgraph.ndata == getobs(graph.ndata, [1, 2])
@test isapprox(getobs(subgraph.edata.e, 1), getobs(graph.edata.e, 1); atol=1e-6)

graph = GNNGraph(2)
graph = add_edges(graph, ([2], [1]))
nodes = [1]
subgraph = Graphs.induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 1
@test subgraph.num_edges == 0
end
end
4 changes: 4 additions & 0 deletions GraphNeuralNetworks/docs/src/api/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ Modules = [GNNGraphs]
Pages = ["sampling.jl"]
Private = false
```

```@docs
Graphs.induced_subgraph(::GNNGraph, ::Vector{Int})
```
2 changes: 1 addition & 1 deletion GraphNeuralNetworks/src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using ChainRulesCore
using Reexport
using MLUtils: zeros_like

using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
EType, NType # for heteroconvs

Expand Down

0 comments on commit 8dbac17

Please sign in to comment.