From bdc160409ca9c7fbc1fb281b9f3361a473cfd180 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 24 Jul 2024 14:14:17 +0200 Subject: [PATCH] drop_nodes(g, p) -> remove_nodes(g, p) --- GNNGraphs/src/GNNGraphs.jl | 1 - GNNGraphs/src/transform.jl | 38 +++++++++++++++---------------------- GNNGraphs/test/transform.jl | 8 ++++---- 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index 80a764a78..c0dbf7678 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -80,7 +80,6 @@ export add_nodes, perturb_edges, remove_nodes, ppr_diffusion, - drop_nodes, # from MLUtils batch, unbatch, diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 9520d63ad..8df726752 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -307,35 +307,27 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector) end """ - drop_nodes(g::GNNGraph{<:COO_T}, p) + remove_nodes(g::GNNGraph, p) -Randomly drop nodes (and their associated edges) from a GNNGraph based on a given probability. -Dropping nodes is a technique that can be used for graph data augmentation, refering paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf). +Returns a new graph obtained by dropping nodes from `g` with independent probabilities `p`. -# Arguments -- `g`: The input graph from which nodes (and their associated edges) will be dropped. -- `p`: The probability of dropping each node. Default value is `0.5`. - -# Returns -A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability. +# Examples -# Example ```julia -using GraphNeuralNetworks -# Construct a GNNGraph -g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3) -# Drop nodes with a probability of 0.5 -g_new = drop_node(g, 0.5) -println(g_new) +julia> g = GNNGraph([1, 1, 2, 2, 3, 4], [1, 2, 3, 1, 3, 1]) +GNNGraph: + num_nodes: 4 + num_edges: 6 + +julia> g_new = remove_nodes(g, 0.5) +GNNGraph: + num_nodes: 2 + num_edges: 2 ``` """ -function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5) - num_nodes = g.num_nodes - nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes) - - new_g = remove_nodes(g, nodes_to_remove) - - return new_g +function remove_nodes(g::GNNGraph, p::AbstractFloat) + nodes_to_remove = filter(_ -> rand() < p, 1:g.num_nodes) + return remove_nodes(g, nodes_to_remove) end """ diff --git a/GNNGraphs/test/transform.jl b/GNNGraphs/test/transform.jl index 23d92a7de..993ac714a 100644 --- a/GNNGraphs/test/transform.jl +++ b/GNNGraphs/test/transform.jl @@ -247,20 +247,20 @@ end end @test edata_new == edatatest end end -@testset "drop_nodes" begin +@testset "remove_nodes(g, p)" begin if GRAPH_T == :coo Random.seed!(42) s = [1, 1, 2, 3] t = [2, 3, 4, 5] g = GNNGraph(s, t, graph_type = GRAPH_T) - gnew = drop_nodes(g, Float32(0.5)) + gnew = remove_nodes(g, 0.5) @test gnew.num_nodes == 3 - gnew = drop_nodes(g, Float32(1.0)) + gnew = remove_nodes(g, 1.0) @test gnew.num_nodes == 0 - gnew = drop_nodes(g, Float32(0.0)) + gnew = remove_nodes(g, 0.0) @test gnew.num_nodes == 5 end end