diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index b82ea39eb..055ea2789 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -4,7 +4,7 @@ using SparseArrays using Functors: @functor import Graphs using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, - has_self_loops, is_directed + has_self_loops, is_directed, induced_subgraph import NearestNeighbors import NNlib import StatsBase diff --git a/GNNGraphs/src/sampling.jl b/GNNGraphs/src/sampling.jl index 01a601f5b..7e723182a 100644 --- a/GNNGraphs/src/sampling.jl +++ b/GNNGraphs/src/sampling.jl @@ -116,3 +116,87 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1; end return gnew end + + +""" + induced_subgraph(graph, nodes) + +Generates a subgraph from the original graph using the provided `nodes`. +The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph. +If a node has no neighbors, an isolated node will be added to the subgraph. +Returns A new `GNNGraph` containing the subgraph with the specified nodes and their features. + +# Arguments + +- `graph`. The original GNNGraph containing nodes, edges, and node features. +- `nodes``. A vector of node indices to include in the subgraph. + +# Examples + +```julia +julia> s = [1, 2] +2-element Vector{Int64}: + 1 + 2 + +julia> t = [2, 3] +2-element Vector{Int64}: + 2 + 3 + +julia> graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2)) +GNNGraph: + num_nodes: 3 + num_edges: 2 + ndata: + y = 3-element Vector{Float32} + x = 32×3 Matrix{Float32} + edata: + e = 2-element Vector{Float32} + +julia> nodes = [1, 2] +2-element Vector{Int64}: + 1 + 2 + +julia> subgraph = Graphs.induced_subgraph(graph, nodes) +GNNGraph: + num_nodes: 2 + num_edges: 1 + ndata: + y = 2-element Vector{Float32} + x = 32×2 Matrix{Float32} + edata: + e = 1-element Vector{Float32} +``` +""" +function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) + if isempty(nodes) + return GNNGraph() # Return empty graph if no nodes are provided + end + + node_map = Dict(node => i for (i, node) in enumerate(nodes)) + + # Collect edges to add + source = Int[] + target = Int[] + eindices = Int[] + for node in nodes + neighbors = Graphs.neighbors(graph, node, dir = :in) + for neighbor in neighbors + if neighbor in keys(node_map) + push!(target, node_map[node]) + push!(source, node_map[neighbor]) + + eindex = findfirst(x -> x == [neighbor, node], edge_index(graph)) + push!(eindices, eindex) + end + end + end + + # Extract features for the new nodes + new_ndata = getobs(graph.ndata, nodes) + new_edata = getobs(graph.edata, eindices) + + return GNNGraph(source, target, num_nodes = length(node_map), ndata = new_ndata, edata = new_edata) +end diff --git a/GNNGraphs/test/sampling.jl b/GNNGraphs/test/sampling.jl index 658cee8da..9b4cd5b16 100644 --- a/GNNGraphs/test/sampling.jl +++ b/GNNGraphs/test/sampling.jl @@ -45,4 +45,36 @@ if GRAPH_T == :coo @test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID] @test length(union(sg.ndata.NID)) == length(sg.ndata.NID) end + + @testset "induced_subgraph" begin + s = [1, 2] + t = [2, 3] + + graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2)) + + nodes = [1, 2, 3] + subgraph = Graphs.induced_subgraph(graph, nodes) + + @test subgraph.num_nodes == 3 + @test subgraph.num_edges == 2 + @test subgraph.ndata.x == graph.ndata.x + @test subgraph.ndata.y == graph.ndata.y + @test subgraph.edata == graph.edata + + nodes = [1, 2] + subgraph = Graphs.induced_subgraph(graph, nodes) + + @test subgraph.num_nodes == 2 + @test subgraph.num_edges == 1 + @test subgraph.ndata == getobs(graph.ndata, [1, 2]) + @test isapprox(getobs(subgraph.edata.e, 1), getobs(graph.edata.e, 1); atol=1e-6) + + graph = GNNGraph(2) + graph = add_edges(graph, ([2], [1])) + nodes = [1] + subgraph = Graphs.induced_subgraph(graph, nodes) + + @test subgraph.num_nodes == 1 + @test subgraph.num_edges == 0 + end end \ No newline at end of file diff --git a/GraphNeuralNetworks/docs/src/api/gnngraph.md b/GraphNeuralNetworks/docs/src/api/gnngraph.md index de6fc1872..f708c3840 100644 --- a/GraphNeuralNetworks/docs/src/api/gnngraph.md +++ b/GraphNeuralNetworks/docs/src/api/gnngraph.md @@ -88,3 +88,7 @@ Modules = [GNNGraphs] Pages = ["sampling.jl"] Private = false ``` + +```@docs +Graphs.induced_subgraph(::GNNGraph, ::Vector{Int}) +``` \ No newline at end of file diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index cebf7b7d3..c9a227b8d 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -11,7 +11,7 @@ using ChainRulesCore using Reexport using MLUtils: zeros_like -using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, +using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, EType, NType # for heteroconvs