Skip to content

Commit

Permalink
gpu compat
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Jul 17, 2024
1 parent ff71498 commit e5d098d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 38 deletions.
68 changes: 34 additions & 34 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,42 +443,42 @@ GNNGraph:
num_edges: 7 # Two new edges added if the original graph had 5 edges, as 0.5 of 5 rounds to 2.
```
"""
function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio; seed::Int = Random.default_rng())
@assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1"

Random.seed!(seed)

num_current_edges = g.num_edges
num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio)

if num_edges_to_add == 0
return g
end

num_nodes = g.num_nodes
@assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges"

snew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes)
tnew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes)

mask_loops = snew .!= tnew
snew = snew[mask_loops]
tnew = tnew[mask_loops]

while length(snew) < num_edges_to_add
n = num_edges_to_add - length(snew)
snewnew = ceil.(Int, rand(Float32, n) .* num_nodes)
tnewnew = ceil.(Int, rand(Float32, n) .* num_nodes)
mask_new_loops = snewnew .!= tnewnew
snewnew = snewnew[mask_new_loops]
tnewnew = tnewnew[mask_new_loops]
snew = [snew; snewnew]
tnew = [tnew; tnewnew]
function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Float64; rng::AbstractRNG = Random.default_rng())
@assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1"

Random.seed!(rng)

num_current_edges = g.num_edges
num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio)

if num_edges_to_add == 0
return g
end

num_nodes = g.num_nodes
@assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges"

snew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)
tnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)

mask_loops = snew .!= tnew
snew = snew[mask_loops]
tnew = tnew[mask_loops]

while length(snew) < num_edges_to_add
n = num_edges_to_add - length(snew)
snewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
tnewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
mask_new_loops = snewnew .!= tnewnew
snewnew = snewnew[mask_new_loops]
tnewnew = tnewnew[mask_new_loops]
snew = [snew; snewnew]
tnew = [tnew; tnewnew]
end

return add_edges(g, (snew, tnew, nothing))
end

return add_edges(g, (snew, tnew, nothing))
end


### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable
# function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector}
Expand Down
8 changes: 4 additions & 4 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ end
end

@testset "perturb_edges" begin if GRAPH_T == :coo
s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4];
w = Float32[1.0, 2.0, 3.0, 4.0, 5.0];
g = GNNGraph((s, t, w))
g_per = perturb_edges(g, 0.5, seed = 42)
s, t = [1, 2, 3, 4, 5], [2, 3, 4, 5, 1]
g = GNNGraph((s, t))
rng = MersenneTwister(42)
g_per = perturb_edges(g, 0.5, rng=rng)
@assert g_per.num_edges == 8
end end

Expand Down

0 comments on commit e5d098d

Please sign in to comment.