From 847b4f9ba87d7187e0ef35627fb0fc4762c455a8 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sun, 17 Mar 2024 18:45:01 +0530 Subject: [PATCH 01/12] add edge perturbation --- src/GNNGraphs/transform.jl | 59 +++++++++++++++++++++++++++++++++++++ test/GNNGraphs/transform.jl | 7 +++++ 2 files changed, 66 insertions(+) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index ffefacde3..60265509c 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -352,6 +352,65 @@ function add_edges(g::GNNHeteroGraph{<:COO_T}, ntypes, etypes) end +""" + perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Float64; seed::Int=42) + +Perturb the graph `g` by adding random edges, based on a specified `perturb_ratio`. The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. These new edges are added without creating self-loops. Optionally, a random `seed` can be provided to ensure reproducible perturbations. + +The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges. The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges. + +# Parameters +- `g::GNNGraph`: The graph to be perturbed. +- `perturb_ratio::Float64`: The ratio of the number of new edges to add relative to the current number of edges in the graph. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges will be added as new random edges. +- `seed=123`: An optional seed for the random number generator to ensure reproducible results. + +# Examples + +```julia +julia> g = GNNGraph((s, t, w)) +GNNGraph: + num_nodes: 4 + num_edges: 5 + +julia> perturbed_g = perturb_edges(g, 0.2) +GNNGraph: + num_nodes: 4 + num_edges: 6 # One new edge added if the original graph had 5 edges, as 0.2 of 5 is 1. 
+ +julia> perturbed_g = perturb_edges(g, 0.5, seed=42) +GNNGraph: + num_nodes: 4 + num_edges: 7 # Two new edges added if the original graph had 5 edges, as 0.5 of 5 rounds to 2. +``` +""" +function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Float64; seed::Int = Random.default_rng()) + @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1" + + Random.seed!(seed) + + num_current_edges = g.num_edges + num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio) + + if num_edges_to_add == 0 + return g + end + + num_nodes = g.num_nodes + @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" + + snew = Int[] + tnew = Int[] + while length(snew) < num_edges_to_add + s_candidate = rand(1:num_nodes) + t_candidate = rand(1:num_nodes) + if s_candidate != t_candidate + push!(snew, s_candidate) + push!(tnew, t_candidate) + end + end + + return add_edges(g, (snew, tnew, nothing)) +end ### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). 
make it mutable diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index d56ba5d1c..582846f4b 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -149,6 +149,13 @@ end end end +@testset "perturb_edges" begin if GRAPH_T == :coo + s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]; + w = Float32[1.0, 2.0, 3.0, 4.0, 5.0]; + g = GNNGraph((s, t, w)) + g_per = perturb_edges(g, 0.5, seed = 42) + @assert g_per.num_edges == 8 +end end @testset "add_nodes" begin if GRAPH_T == :coo g = rand_graph(6, 4, ndata = rand(2, 6), graph_type = GRAPH_T) From 4141275a87d5e9908218fb64bdfd61dc2934e906 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sun, 17 Mar 2024 18:57:05 +0530 Subject: [PATCH 02/12] add to gnngraphs --- src/GNNGraphs/GNNGraphs.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl index 59cc9d9e4..237632c6e 100644 --- a/src/GNNGraphs/GNNGraphs.jl +++ b/src/GNNGraphs/GNNGraphs.jl @@ -77,6 +77,7 @@ export add_nodes, to_bidirected, to_unidirected, random_walk_pe, + perturb_edges, # from Flux batch, unbatch, From 26a4493fc654b33a7cf8666df48132240d689218 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 21 Mar 2024 12:41:53 +0530 Subject: [PATCH 03/12] Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index d9edd4bee..8b6bd8412 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -413,7 +413,7 @@ function add_edges(g::GNNHeteroGraph{<:COO_T}, end """ - perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Float64; seed::Int=42) + perturb_edges(g::GNNGraph, perturb_ratio; [seed]) Perturb the graph `g` by adding random edges, based on a specified `perturb_ratio`. 
The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. These new edges are added without creating self-loops. Optionally, a random `seed` can be provided to ensure reproducible perturbations. From bc3d01824e6c185ac6e1ca2a3dea2abc027e3e4d Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 4 Jul 2024 01:06:56 +0530 Subject: [PATCH 04/12] loop fix --- src/GNNGraphs/transform.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 8b6bd8412..9686f5fcb 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -458,15 +458,22 @@ function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Float64; seed::Int = num_nodes = g.num_nodes @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" - snew = Int[] - tnew = Int[] + snew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) + tnew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) + + mask_loops = snew .!= tnew + snew = snew[mask_loops] + tnew = tnew[mask_loops] + while length(snew) < num_edges_to_add - s_candidate = rand(1:num_nodes) - t_candidate = rand(1:num_nodes) - if s_candidate != t_candidate - push!(snew, s_candidate) - push!(tnew, t_candidate) - end + n = num_edges_to_add - length(snew) + snewnew = ceil.(Int, rand(Float32, n) .* num_nodes) + tnewnew = ceil.(Int, rand(Float32, n) .* num_nodes) + mask_new_loops = snewnew .!= tnewnew + snewnew = snewnew[mask_new_loops] + tnewnew = tnewnew[mask_new_loops] + snew = [snew; snewnew] + tnew = [tnew; tnewnew] end return add_edges(g, (snew, tnew, nothing)) From cf21abc03eccca2fd6b04647608b527633273560 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 5 Jul 2024 01:20:51 +0530 Subject: [PATCH 05/12] Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/transform.jl | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 9686f5fcb..b6306b4f4 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -421,7 +421,7 @@ The function returns a new `GNNGraph` instance that shares some of the underlyin # Parameters - `g::GNNGraph`: The graph to be perturbed. -- `perturb_ratio::Float64`: The ratio of the number of new edges to add relative to the current number of edges in the graph. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges will be added as new random edges. +- `perturb_ratio`: The ratio of the number of new edges to add relative to the current number of edges in the graph. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges will be added as new random edges. - `seed=123`: An optional seed for the random number generator to ensure reproducible results. # Examples From c04a832cd695a04ccaf60d7f039706cc92328764 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 5 Jul 2024 01:22:40 +0530 Subject: [PATCH 06/12] Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index b6306b4f4..76eb43667 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -443,7 +443,7 @@ GNNGraph: num_edges: 7 # Two new edges added if the original graph had 5 edges, as 0.5 of 5 rounds to 2. 
``` """ -function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Float64; seed::Int = Random.default_rng()) +function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio; seed::Int = Random.default_rng()) @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1" Random.seed!(seed) From 0d3017deabd8d6032bafb3f330127e18ef332a51 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 5 Jul 2024 01:22:50 +0530 Subject: [PATCH 07/12] Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 76eb43667..0b66a29d3 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -413,7 +413,7 @@ function add_edges(g::GNNHeteroGraph{<:COO_T}, end """ - perturb_edges(g::GNNGraph, perturb_ratio; [seed]) + perturb_edges([rng], g::GNNGraph, perturb_ratio) Perturb the graph `g` by adding random edges, based on a specified `perturb_ratio`. The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. These new edges are added without creating self-loops. Optionally, a random `seed` can be provided to ensure reproducible perturbations. 
From c68fd0f75def6c3ab64b00a5fb66d2e380c4a9d4 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 5 Jul 2024 01:24:24 +0530 Subject: [PATCH 08/12] Update transform.jl --- src/GNNGraphs/transform.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 0b66a29d3..6c26961a9 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -458,8 +458,8 @@ function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio; seed::Int = Random.d num_nodes = g.num_nodes @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" - snew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) - tnew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) + snew = ceil.(eltype(s), rand_like(rng, s, Float32, n) .* num_nodes) + tnew = ceil.(eltype(s), rand_like(rng, s, Float32, n) .* num_nodes) mask_loops = snew .!= tnew snew = snew[mask_loops] From ff71498575ca2aa7dbac1ed5300c3949c7f9a72d Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 5 Jul 2024 01:40:25 +0530 Subject: [PATCH 09/12] Update transform.jl --- src/GNNGraphs/transform.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 6c26961a9..0b66a29d3 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -458,8 +458,8 @@ function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio; seed::Int = Random.d num_nodes = g.num_nodes @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" - snew = ceil.(eltype(s), rand_like(rng, s, Float32, n) .* num_nodes) - tnew = ceil.(eltype(s), rand_like(rng, s, Float32, n) .* num_nodes) + snew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) + tnew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) mask_loops = snew .!= tnew snew = snew[mask_loops] From 
e5d098d67374b6ee83a616f4fe3f06f85e9fdc63 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Wed, 17 Jul 2024 17:22:30 +0530 Subject: [PATCH 10/12] gpu compat --- src/GNNGraphs/transform.jl | 68 ++++++++++++++++++------------------- test/GNNGraphs/transform.jl | 8 ++--- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 0b66a29d3..c0c102a03 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -443,42 +443,42 @@ GNNGraph: num_edges: 7 # Two new edges added if the original graph had 5 edges, as 0.5 of 5 rounds to 2. ``` """ -function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio; seed::Int = Random.default_rng()) - @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1" - - Random.seed!(seed) - - num_current_edges = g.num_edges - num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio) - - if num_edges_to_add == 0 - return g - end - - num_nodes = g.num_nodes - @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" - - snew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) - tnew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) - - mask_loops = snew .!= tnew - snew = snew[mask_loops] - tnew = tnew[mask_loops] - - while length(snew) < num_edges_to_add - n = num_edges_to_add - length(snew) - snewnew = ceil.(Int, rand(Float32, n) .* num_nodes) - tnewnew = ceil.(Int, rand(Float32, n) .* num_nodes) - mask_new_loops = snewnew .!= tnewnew - snewnew = snewnew[mask_new_loops] - tnewnew = tnewnew[mask_new_loops] - snew = [snew; snewnew] - tnew = [tnew; tnewnew] +function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio; rng::AbstractRNG = Random.default_rng()) + @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1" + + # Use `rng` exactly as passed: calling Random.seed!(rng) here would reseed it from system entropy and discard the caller's seed. + + num_current_edges = g.num_edges + num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio) + + if num_edges_to_add == 
0 + return g # nothing to add: the input graph itself is returned, not a copy + end + + num_nodes = g.num_nodes + @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" + + snew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes) # NOTE(review): a draw of exactly 0 would give node index 0 - confirm rand_like's range + tnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes) + + mask_loops = snew .!= tnew # true where source and target differ, i.e. the pair is not a self-loop + snew = snew[mask_loops] + tnew = tnew[mask_loops] + + while length(snew) < num_edges_to_add # top up with fresh draws until enough non-loop pairs exist + n = num_edges_to_add - length(snew) + snewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes) + tnewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes) + mask_new_loops = snewnew .!= tnewnew + snewnew = snewnew[mask_new_loops] + tnewnew = tnewnew[mask_new_loops] + snew = [snew; snewnew] + tnew = [tnew; tnewnew] + end + + return add_edges(g, (snew, tnew, nothing)) end - return add_edges(g, (snew, tnew, nothing)) -end - ### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable # function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector} diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index ac688015e..1930adaea 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -150,10 +150,10 @@ end end @testset "perturb_edges" begin if GRAPH_T == :coo - s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]; - w = Float32[1.0, 2.0, 3.0, 4.0, 5.0]; - g = GNNGraph((s, t, w)) - g_per = perturb_edges(g, 0.5, seed = 42) + s, t = [1, 2, 3, 4, 5], [2, 3, 4, 5, 1] + g = GNNGraph((s, t)) + rng = MersenneTwister(42) + g_per = perturb_edges(g, 0.5, rng=rng) @assert g_per.num_edges == 8 end end From df291ad8117a7b60d793de05c203ac718df5ae8b Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 18 Jul 2024 18:01:01 +0530 Subject: [PATCH 11/12] include package --- src/GNNGraphs/GNNGraphs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/GNNGraphs/GNNGraphs.jl 
b/src/GNNGraphs/GNNGraphs.jl index 3bf411727..d40489936 100644 --- a/src/GNNGraphs/GNNGraphs.jl +++ b/src/GNNGraphs/GNNGraphs.jl @@ -14,7 +14,7 @@ import KrylovKit using ChainRulesCore using LinearAlgebra, Random, Statistics import MLUtils -using MLUtils: getobs, numobs, ones_like, zeros_like +using MLUtils: getobs, numobs, ones_like, zeros_like, rand_like import Functors include("chainrules.jl") # hacks for differentiability From 42a31a7413aef0bce118c62c5f254e8d3873e24b Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 18 Jul 2024 17:49:11 +0200 Subject: [PATCH 12/12] Update test/GNNGraphs/transform.jl --- test/GNNGraphs/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index b2ab14801..70570d155 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -182,7 +182,7 @@ end g = GNNGraph((s, t)) rng = MersenneTwister(42) g_per = perturb_edges(g, 0.5, rng=rng) - @assert g_per.num_edges == 8 + @test g_per.num_edges == 8 end end @testset "remove_nodes" begin if GRAPH_T == :coo