diff --git a/GNNGraphs/test/sampling.jl b/GNNGraphs/test/sampling.jl index c9c12634c..8a1044b27 100644 --- a/GNNGraphs/test/sampling.jl +++ b/GNNGraphs/test/sampling.jl @@ -61,6 +61,14 @@ if GRAPH_T == :coo @test subgraph.ndata.y == graph.ndata.y @test subgraph.edata == graph.edata + nodes = [1, 2] + subgraph = Graphs.induced_subgraph(graph, nodes) + + @test subgraph.num_nodes == 2 + @test subgraph.num_edges == 1 + @test subgraph.ndata == getobs(graph.ndata, [1, 2]) + @test subgraph.edata.e == getobs(graph.edata, 1) + graph = GNNGraph(2) graph = add_edges(graph, ([2], [1])) nodes = [1]