Skip to content

Commit

Permalink
drop_nodes(g, p) -> remove_nodes(g, p)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 24, 2024
1 parent 43d4ab0 commit bdc1604
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 28 deletions.
1 change: 0 additions & 1 deletion GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ export add_nodes,
perturb_edges,
remove_nodes,
ppr_diffusion,
drop_nodes,
# from MLUtils
batch,
unbatch,
Expand Down
38 changes: 15 additions & 23 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,35 +307,27 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
end

"""
drop_nodes(g::GNNGraph{<:COO_T}, p)
remove_nodes(g::GNNGraph, p)
Randomly drop nodes (and their associated edges) from a GNNGraph based on a given probability.
Dropping nodes is a technique that can be used for graph data augmentation, refering paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf).
Returns a new graph obtained by dropping nodes from `g` with independent probabilities `p`.
# Arguments
- `g`: The input graph from which nodes (and their associated edges) will be dropped.
- `p`: The probability of dropping each node. Default value is `0.5`.
# Returns
A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.
# Examples
# Example
```julia
using GraphNeuralNetworks
# Construct a GNNGraph
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3)
# Drop nodes with a probability of 0.5
g_new = drop_node(g, 0.5)
println(g_new)
julia> g = GNNGraph([1, 1, 2, 2, 3, 4], [1, 2, 3, 1, 3, 1])
GNNGraph:
num_nodes: 4
num_edges: 6
julia> g_new = remove_nodes(g, 0.5)
GNNGraph:
num_nodes: 2
num_edges: 2
```
"""
function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5)
num_nodes = g.num_nodes
nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)

new_g = remove_nodes(g, nodes_to_remove)

return new_g
function remove_nodes(g::GNNGraph, p::AbstractFloat)
nodes_to_remove = filter(_ -> rand() < p, 1:g.num_nodes)
return remove_nodes(g, nodes_to_remove)
end

"""
Expand Down
8 changes: 4 additions & 4 deletions GNNGraphs/test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,20 @@ end end
@test edata_new == edatatest
end end

@testset "drop_nodes" begin
@testset "remove_nodes(g, p)" begin
if GRAPH_T == :coo
Random.seed!(42)
s = [1, 1, 2, 3]
t = [2, 3, 4, 5]
g = GNNGraph(s, t, graph_type = GRAPH_T)

gnew = drop_nodes(g, Float32(0.5))
gnew = remove_nodes(g, 0.5)
@test gnew.num_nodes == 3

gnew = drop_nodes(g, Float32(1.0))
gnew = remove_nodes(g, 1.0)
@test gnew.num_nodes == 0

gnew = drop_nodes(g, Float32(0.0))
gnew = remove_nodes(g, 0.0)
@test gnew.num_nodes == 5
end
end
Expand Down

0 comments on commit bdc1604

Please sign in to comment.