Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added remove_edges function #414

Merged
merged 18 commits into from
Mar 21, 2024
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Manifest.toml
.vscode
LocalPreferences.toml
.DS_Store
/test.jl
try.jl
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export add_nodes,
negative_sample,
rand_edge_split,
remove_self_loops,
remove_edges,
remove_multi_edges,
set_edge_weight,
to_bidirected,
Expand Down
45 changes: 45 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,51 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
g.ndata, g.edata, g.gdata)
end

"""
remove_edges(g::GNNGraph, edges_to_remove::Vector{Int})
rbSparky marked this conversation as resolved.
Show resolved Hide resolved

Remove specified edges from a GNNGraph.

# Arguments
- `g`: The input graph from which edges will be removed.
- `edges_to_remove`: Vector of edge indices to be removed.

# Returns
A new GNNGraph with the specified edges removed.

# Example
```julia
using GraphNeuralNetworks

# Construct a GNNGraph
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])

# Remove the second edge
g_new = remove_edges(g, [2])

println(g_new)
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
```
"""
function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove)
s, t = edge_index(g)
w = get_edge_weight(g)
edata = g.edata

mask_to_keep = trues(length(s))

mask_to_keep[edges_to_remove] .= false

s = s[mask_to_keep]
t = t[mask_to_keep]
edata = getobs(edata, mask_to_keep)
w = isnothing(w) ? nothing : getobs(w, mask_to_keep)

GNNGraph((s, t, w),
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
g.num_nodes, length(s), g.num_graphs,
g.graph_indicator,
g.ndata, edata, g.gdata)
end

"""
remove_multi_edges(g::GNNGraph; aggr=+)

Expand Down
13 changes: 13 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,19 @@ end
@test nodemap == 1:(g1.num_nodes)
end

@testset "remove_edges" begin
if GRAPH_T == :coo
s = [1, 1, 2, 3]
t = [2, 3, 4, 5]
g = GNNGraph(s, t, graph_type = GRAPH_T)
gnew = remove_edges(g, [1])
new_s, new_t = edge_index(gnew)
@test gnew.num_edges == 3
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
@test new_s == s[2:end]
@test new_t == t[2:end]
end
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
end

@testset "add_edges" begin
if GRAPH_T == :coo
s = [1, 1, 2, 3]
Expand Down
Loading