diff --git a/GNNGraphs/src/sampling.jl b/GNNGraphs/src/sampling.jl index 2f5298317..9f7b1487d 100644 --- a/GNNGraphs/src/sampling.jl +++ b/GNNGraphs/src/sampling.jl @@ -114,7 +114,7 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1; graph_indicator, ndata, edata, g.gdata) end - return gnew + return gne end """ diff --git a/GNNGraphs/test/sampling.jl b/GNNGraphs/test/sampling.jl index 16131b04b..bf6e7ff72 100644 --- a/GNNGraphs/test/sampling.jl +++ b/GNNGraphs/test/sampling.jl @@ -64,11 +64,10 @@ if GRAPH_T == :coo nodes = [1, 2] subgraph = Graphs.induced_subgraph(graph, nodes) - @test subgraph.num_nodes == 3 - @test subgraph.num_edges == 2 - @test subgraph.ndata.x == graph.ndata.x - @test subgraph.ndata.y == graph.ndata.y - @test subgraph.edata == graph.edata + @test subgraph.num_nodes == 2 + @test subgraph.num_edges == 1 + @test subgraph.ndata == getobs(g.ndata, [1, 2]) + @test subgraph.edata == getobs(graph.edata, 1) graph = GNNGraph(2) graph = add_edges(graph, ([2], [1]))