Skip to content

Commit

Permalink
fix runtest
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 23, 2024
1 parent 31695a0 commit 092c4fa
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 47 deletions.
2 changes: 1 addition & 1 deletion GNNGraphs/src/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::Nothing, num_
end

function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::AbstractVector, num_nodes::Int)
degs = fill!(similar(s, T, num_nodes), 0)
degs = zeros_like(s, T, num_nodes)

if dir [:out, :both]
degs = degs .+ NNlib.scatter(+, edge_weight, s, dstsize = (num_nodes,))
Expand Down
3 changes: 1 addition & 2 deletions GNNGraphs/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ tests = [
"sampling",
"gnnheterograph",
"temporalsnapshotsgnngraph",
"ext/SimpleWeightedGraphs/SimpleWeightedGraphs"
"ext/SimpleWeightedGraphs"
]

!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
Expand All @@ -49,7 +49,6 @@ for graph_type in (:coo, :dense, :sparse)
# global TEST_GPU = false

@testset "$t" for t in tests
t == "GNNGraphs/sampling" && GRAPH_T != :coo && continue
include("$t.jl")
end
end
90 changes: 46 additions & 44 deletions GNNGraphs/test/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,48 @@
@testset "sample_neighbors" begin
# replace = false
dir = :in
nodes = 2:3
g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T)
sg = sample_neighbors(g, nodes; dir)
@test sg.num_nodes == 10
@test sg.num_edges == sum(degree(g, i; dir) for i in nodes)
@test size(sg.edata.EID) == (sg.num_edges,)
@test length(union(sg.edata.EID)) == length(sg.edata.EID)
adjlist = adjacency_list(g; dir)
s, t = edge_index(sg)
@test all(t .∈ Ref(nodes))
for i in nodes
@test sort(neighbors(sg, i; dir)) == sort(neighbors(g, i; dir))
end
if GRAPH_T == :coo
@testset "sample_neighbors" begin
# replace = false
dir = :in
nodes = 2:3
g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T)
sg = sample_neighbors(g, nodes; dir)
@test sg.num_nodes == 10
@test sg.num_edges == sum(degree(g, i; dir) for i in nodes)
@test size(sg.edata.EID) == (sg.num_edges,)
@test length(union(sg.edata.EID)) == length(sg.edata.EID)
adjlist = adjacency_list(g; dir)
s, t = edge_index(sg)
@test all(t .∈ Ref(nodes))
for i in nodes
@test sort(neighbors(sg, i; dir)) == sort(neighbors(g, i; dir))
end

# replace = true
dir = :out
nodes = 2:3
K = 2
g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T)
sg = sample_neighbors(g, nodes, K; dir, replace = true)
@test sg.num_nodes == 10
@test sg.num_edges == sum(K for i in nodes)
@test size(sg.edata.EID) == (sg.num_edges,)
adjlist = adjacency_list(g; dir)
s, t = edge_index(sg)
@test all(s .∈ Ref(nodes))
for i in nodes
@test issubset(neighbors(sg, i; dir), adjlist[i])
end
# replace = true
dir = :out
nodes = 2:3
K = 2
g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T)
sg = sample_neighbors(g, nodes, K; dir, replace = true)
@test sg.num_nodes == 10
@test sg.num_edges == sum(K for i in nodes)
@test size(sg.edata.EID) == (sg.num_edges,)
adjlist = adjacency_list(g; dir)
s, t = edge_index(sg)
@test all(s .∈ Ref(nodes))
for i in nodes
@test issubset(neighbors(sg, i; dir), adjlist[i])
end

# dropnodes = true
dir = :in
nodes = 2:3
g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T)
g = GNNGraph(g, ndata = (x1 = rand(10),), edata = (e1 = rand(40),))
sg = sample_neighbors(g, nodes; dir, dropnodes = true)
@test sg.num_edges == sum(degree(g, i; dir) for i in nodes)
@test size(sg.edata.EID) == (sg.num_edges,)
@test size(sg.ndata.NID) == (sg.num_nodes,)
@test sg.edata.e1 == g.edata.e1[sg.edata.EID]
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
end
# dropnodes = true
dir = :in
nodes = 2:3
g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T)
g = GNNGraph(g, ndata = (x1 = rand(10),), edata = (e1 = rand(40),))
sg = sample_neighbors(g, nodes; dir, dropnodes = true)
@test sg.num_edges == sum(degree(g, i; dir) for i in nodes)
@test size(sg.edata.EID) == (sg.num_edges,)
@test size(sg.ndata.NID) == (sg.num_nodes,)
@test sg.edata.e1 == g.edata.e1[sg.edata.EID]
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
end
end

0 comments on commit 092c4fa

Please sign in to comment.