Skip to content

Commit

Permalink
feat: add edata&ndata support for induced_subgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
askorupka committed Oct 6, 2024
1 parent 000b670 commit 5041905
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
19 changes: 7 additions & 12 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,28 +141,23 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
# Collect edges to add
source = Int[]
target = Int[]
backup_gnn = GNNGraph()
eindices = Int[]
for node in nodes
neighbors = Graphs.neighbors(graph, node, dir = :in)
if isempty(neighbors)
backup_gnn = add_nodes(backup_gnn, 1)
end
for neighbor in neighbors
if neighbor in keys(node_map)
push!(target, node_map[node])
push!(source, node_map[neighbor])

eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
push!(eindices, eindex)
end
end
end

# Extract features for the new nodes
#new_features = graph.x[:, nodes]

if isempty(source) && isempty(target)
#backup_gnn.ndata.x = new_features ### TODO fix & add edges data (probably push themto the new vector?)
return backup_gnn # Return empty graph if no nodes are provided
end
new_ndata = getobs(graph.ndata, nodes)
new_edata = getobs(graph.edata, eindices)

return GNNGraph(source, target, num_nodes = length(node_map))
#, ndata = new_features) # Return the new GNNGraph with subgraph and features
return GNNGraph(source, target, num_nodes = length(node_map), ndata = new_ndata, edata = new_edata)
end
32 changes: 20 additions & 12 deletions GNNGraphs/test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,26 @@ if GRAPH_T == :coo
end

@testset "induced_subgraph" begin
# Create a simple GNNGraph with two nodes and one edge
s = [1]
t = [2]
### TODO add data
graph = GNNGraph((s, t))
s = [1, 2]
t = [2, 3]

# Induce subgraph on both nodes
nodes = [1, 2]
subgraph = induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 2 # Subgraph should have 2 nodes
@test subgraph.num_edges == 1 # Subgraph should have 1 edge
### TODO @test subgraph.ndata.x == graph.x[:, nodes] # Features should match the original graph
graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2))

nodes = [1, 2, 3]
subgraph = Graphs.induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 3
@test subgraph.num_edges == 2
@test subgraph.ndata.x == graph.ndata.x
@test subgraph.ndata.y == graph.ndata.y
@test subgraph.edata == graph.edata

graph = GNNGraph(2)
graph = add_edges(graph, ([2], [1]))
nodes = [1]
subgraph = Graphs.induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 1
@test subgraph.num_edges == 0
end
end

0 comments on commit 5041905

Please sign in to comment.