Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added perturb_edges function #423

Merged
merged 14 commits into from
Jul 18, 2024
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import KrylovKit
using ChainRulesCore
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, rand_like
import Functors

include("chainrules.jl") # hacks for differentiability
Expand Down Expand Up @@ -78,6 +78,7 @@ export add_nodes,
to_bidirected,
to_unidirected,
random_walk_pe,
perturb_edges,
remove_nodes,
ppr_diffusion,
drop_nodes,
Expand Down
66 changes: 66 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,72 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
ntypes, etypes)
end

"""
perturb_edges([rng], g::GNNGraph, perturb_ratio)
rbSparky marked this conversation as resolved.
Show resolved Hide resolved

Perturb the graph `g` by adding random edges, based on a specified `perturb_ratio`. The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. These new edges are added without creating self-loops. Optionally, a random `seed` can be provided to ensure reproducible perturbations.

The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges. The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges.

# Parameters
- `g::GNNGraph`: The graph to be perturbed.
- `perturb_ratio`: The ratio of the number of new edges to add relative to the current number of edges in the graph. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges will be added as new random edges.
- `seed=123`: An optional seed for the random number generator to ensure reproducible results.

# Examples

```julia
julia> g = GNNGraph((s, t, w))
GNNGraph:
num_nodes: 4
num_edges: 5

julia> perturbed_g = perturb_edges(g, 0.2)
GNNGraph:
num_nodes: 4
num_edges: 6 # One new edge added if the original graph had 5 edges, as 0.2 of 5 is 1.

julia> perturbed_g = perturb_edges(g, 0.5, seed=42)
GNNGraph:
num_nodes: 4
num_edges: 7 # Two new edges added if the original graph had 5 edges, as 0.5 of 5 rounds to 2.
```
"""
function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Float64; rng::AbstractRNG = Random.default_rng())
@assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1"

Random.seed!(rng)

num_current_edges = g.num_edges
num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio)

if num_edges_to_add == 0
return g
end

num_nodes = g.num_nodes
@assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges"

snew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)
tnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)

mask_loops = snew .!= tnew
snew = snew[mask_loops]
tnew = tnew[mask_loops]

while length(snew) < num_edges_to_add
n = num_edges_to_add - length(snew)
snewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
tnewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
mask_new_loops = snewnew .!= tnewnew
snewnew = snewnew[mask_new_loops]
tnewnew = tnewnew[mask_new_loops]
snew = [snew; snewnew]
tnew = [tnew; tnewnew]
end

return add_edges(g, (snew, tnew, nothing))
end


### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable
Expand Down
8 changes: 8 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ end
end
end

@testset "perturb_edges" begin if GRAPH_T == :coo
s, t = [1, 2, 3, 4, 5], [2, 3, 4, 5, 1]
g = GNNGraph((s, t))
rng = MersenneTwister(42)
g_per = perturb_edges(g, 0.5, rng=rng)
@assert g_per.num_edges == 8
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
end end

@testset "remove_nodes" begin if GRAPH_T == :coo
#single node
s = [1, 1, 2, 3]
Expand Down
Loading