Skip to content

Commit

Permalink
fix perturb_edges
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 23, 2024
1 parent 22e55c7 commit ca83092
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
1 change: 1 addition & 0 deletions GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ export add_nodes,
to_bidirected,
to_unidirected,
random_walk_pe,
pertub_edges,
remove_nodes,
ppr_diffusion,
drop_nodes,
Expand Down
68 changes: 68 additions & 0 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,74 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
end


"""
perturb_edges([rng], g::GNNGraph, perturb_ratio)
Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`.
The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph.
These new edges are added without creating self-loops.
Optionally, a random `seed` can be provided to ensure reproducible perturbations.
The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges.
The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges.
# Arguments
- `g::GNNGraph`: The graph to be perturbed.
- `perturb_ratio`: The ratio of the number of new edges to add relative to the current number of edges in the graph. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges will be added as new random edges.
- `rng`: An optionalrandom number generator to ensure reproducible results.
# Examples
```julia
julia> g = GNNGraph((s, t, w))
GNNGraph:
num_nodes: 4
num_edges: 5
julia> perturbed_g = perturb_edges(g, 0.2)
GNNGraph:
num_nodes: 4
num_edges: 6
```
"""
perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::AbstractFloat) =
pertub_edges(Random.default_rng(), g, perturb_ratio)

function perturb_edges(rng::AbstractRNG, g::GNNGraph{<:COO_T}, perturb_ratio::AbstractFloat)
@assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1"

num_current_edges = g.num_edges
num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio)

if num_edges_to_add == 0
return g
end

num_nodes = g.num_nodes
@assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges"

snew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)
tnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)

mask_loops = snew .!= tnew
snew = snew[mask_loops]
tnew = tnew[mask_loops]

while length(snew) < num_edges_to_add
n = num_edges_to_add - length(snew)
snewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
tnewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
mask_new_loops = snewnew .!= tnewnew
snewnew = snewnew[mask_new_loops]
tnewnew = tnewnew[mask_new_loops]
snew = [snew; snewnew]
tnew = [tnew; tnewnew]
end

return add_edges(g, (snew, tnew, nothing))
end


### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable
# function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector}
Expand Down

0 comments on commit ca83092

Please sign in to comment.