diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index b82ea39eb..2d9850f4c 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -97,6 +97,7 @@ export rand_graph, include("sampling.jl") export sample_neighbors +export induced_subgraph include("operators.jl") # Base.intersect diff --git a/GNNGraphs/src/sampling.jl b/GNNGraphs/src/sampling.jl index 01a601f5b..48efa44d5 100644 --- a/GNNGraphs/src/sampling.jl +++ b/GNNGraphs/src/sampling.jl @@ -116,3 +116,53 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1; end return gnew end + +""" + induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph + +Generates a subgraph from the original graph using the provided `nodes`. + The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph. + If a node has no neighbors, an isolated node will be added to the subgraph. + +# Arguments: +- `graph::GNNGraph`: The original graph containing nodes, edges, and node features. +- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph. + +# Returns: +A new `GNNGraph` containing the subgraph with the specified nodes and their features. +""" +function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) + if isempty(nodes) + return GNNGraph() # Return empty graph if no nodes are provided + end + + node_map = Dict(node => i for (i, node) in enumerate(nodes)) + + # Collect edges to add + source = Int[] + target = Int[] + backup_gnn = GNNGraph() + for node in nodes + neighbors = Graphs.neighbors(graph, node, dir = :in) + if isempty(neighbors) + backup_gnn = add_nodes(backup_gnn, 1) + end + for neighbor in neighbors + if neighbor in keys(node_map) + push!(source, node_map[node]) + push!(target, node_map[neighbor]) + end + end + end + + # Extract features for the new nodes + #new_features = graph.x[:, nodes] + + if isempty(source) && isempty(target) + #backup_gnn.ndata.x = new_features ### TODO fix & add edges data (probably push themto the new vector?) + return backup_gnn # Return empty graph if no nodes are provided + end + + return GNNGraph(source, target) + #, ndata = new_features) # Return the new GNNGraph with subgraph and features +end diff --git a/GNNGraphs/test/sampling.jl b/GNNGraphs/test/sampling.jl index 658cee8da..284589ee9 100644 --- a/GNNGraphs/test/sampling.jl +++ b/GNNGraphs/test/sampling.jl @@ -45,4 +45,20 @@ if GRAPH_T == :coo @test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID] @test length(union(sg.ndata.NID)) == length(sg.ndata.NID) end + + @testset "induced_subgraph" begin + # Create a simple GNNGraph with two nodes and one edge + graph = GNNGraph() # Initialize graph + add_nodes!(graph, 2) # Add 2 nodes + add_edge!(graph, 1, 2) # Add an edge from node 1 to node 2 + graph.x = rand(10, 2) # Assign random features to both nodes (10 features per node) + + # Induce subgraph on both nodes + nodes = [1, 2] + subgraph = induced_subgraph(graph, nodes) + + @test num_nodes(subgraph) == 2 # Subgraph should have 2 nodes + @test num_nodes(subgraph) == 1 # Subgraph should have 1 edge + ### TODO @test subgraph.ndata.x == graph.x[:, nodes] # Features should match the original graph + end end \ No newline at end of file