-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
253 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
```@meta
CurrentModule = GraphNeuralNetworks
```

# Samplers


## Docs

```@autodocs
Modules = [GraphNeuralNetworks]
Pages   = ["samplers.jl"]
Private = false
```
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
""" | ||
NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size]) | ||
A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs). | ||
It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training | ||
originally introduced in "Inductive Representation Learning on Large Graphs" paper. | ||
[see https://arxiv.org/abs/1706.02216] | ||
# Fields | ||
- `graph::GNNGraph`: The input graph. | ||
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer. | ||
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling. | ||
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors). | ||
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`. | ||
# Usage | ||
```jldoctest | ||
julia> loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) | ||
julia> batch_counter = 0 | ||
julia> for mini_batch_gnn in loader | ||
batch_counter += 1 | ||
println("Batch ", batch_counter, ": Nodes in mini-batch graph: ", nv(mini_batch_gnn)) | ||
``` | ||
""" | ||
struct NeighborLoader | ||
graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl) | ||
num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer | ||
input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling) | ||
num_layers::Int # Number of layers for neighborhood expansion | ||
batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given | ||
neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation | ||
end | ||
|
||
# Keyword constructor for NeighborLoader.
#
# - `input_nodes` may be omitted (or `nothing`); it then defaults to every node of `graph`,
#   i.e. `1:graph.num_nodes`.
# - `batch_size` may be omitted (or `nothing`); it then defaults to the number of input nodes,
#   so that all of them are processed in a single batch.
# - The neighbor cache always starts out empty.
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int},
                        input_nodes::Union{Vector{Int}, Nothing}=nothing,
                        num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
    # Resolve the defaults first so that the batch size is derived from the *effective*
    # node list rather than from a possibly-`nothing` argument.
    nodes = input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes
    effective_batch_size = batch_size === nothing ? length(nodes) : batch_size
    return NeighborLoader(graph, num_neighbors, nodes, num_layers, effective_batch_size,
                          Dict{Int, Vector{Int}}())
end
|
||
# Return the incoming neighbors (`dir = :in`) of `node` in `loader.graph`.
# The first lookup for a node queries the graph and memoizes the result in
# `loader.neighbors_cache`; later lookups are served from that cache.
function get_neighbors(loader::NeighborLoader, node::Int)
    cache = loader.neighbors_cache
    # Fast path: this node has been looked up before
    haskey(cache, node) && return cache[node]

    # Slow path: ask the graph once and remember the answer
    in_nbrs = Graphs.neighbors(loader.graph, node, dir = :in)
    cache[node] = in_nbrs
    return in_nbrs
end
|
||
# Sample neighbors of `node` for GNN layer `layer`.
#
# Returns at most `loader.num_neighbors[layer]` *distinct* neighbors, drawn uniformly at
# random without replacement; if the node has fewer neighbors than requested, all of them
# are returned (in random order). A node without neighbors yields an empty `Int[]`.
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
    neighbors = get_neighbors(loader, node)
    isempty(neighbors) && return Int[]

    # Limit to the required number of samples for this layer
    num_samples = min(loader.num_neighbors[layer], length(neighbors))

    # Work on a copy: `neighbors` is the very vector stored in the loader's cache and must
    # not be permuted in place.
    pool = copy(neighbors)
    # Partial Fisher-Yates shuffle: after step `i` the first `i` slots hold a uniform
    # sample without replacement, so only `num_samples` swaps are needed.
    for i in 1:num_samples
        j = rand(i:length(pool))
        pool[i], pool[j] = pool[j], pool[i]
    end
    return pool[1:num_samples]
end
|
||
# Iterator protocol for NeighborLoader with lazy batch loading.
#
# `state` is the 1-based position in `loader.input_nodes` at which the next batch starts.
# Each call takes up to `batch_size` input nodes, expands every one of them for
# `loader.num_layers` rounds of neighbor sampling, and returns the tuple
# `(mini_batch, next_state)`, where `mini_batch` is the result of
# `Graphs.induced_subgraph` on all collected nodes. Returns `nothing` once every input
# node has been consumed.
#
# Throws an `ArgumentError` if the loader's batch size is not positive, since such a value
# would never advance `state` and the iteration would not terminate.
function Base.iterate(loader::NeighborLoader, state=1)
    num_inputs = length(loader.input_nodes)
    state > num_inputs && return nothing  # All input nodes have been batched

    # The `batch_size` field is allowed to be `nothing` (e.g. when the struct is built
    # positionally); treat that as "all input nodes in one batch".
    requested = something(loader.batch_size, num_inputs)
    requested > 0 ||
        throw(ArgumentError("batch_size must be positive, got $requested"))

    # The last batch may be smaller when the input nodes do not divide evenly
    batch_size = min(requested, num_inputs - state + 1)
    batch_nodes = loader.input_nodes[state:(state + batch_size - 1)]

    # Nodes of the mini-batch subgraph; always contains the batch's input nodes
    subgraph_nodes = Set(batch_nodes)

    for node in batch_nodes
        # Nodes reached in the previous round, starting with the input node itself
        frontier = Set([node])

        for layer in 1:loader.num_layers
            next_frontier = Set{Int}()
            for n in frontier
                # Set semantics drop duplicates among the sampled neighbors
                union!(next_frontier, sample_nbrs(loader, n, layer))
            end
            frontier = next_frontier
            # Nothing new was reached, so deeper layers cannot add nodes either
            isempty(frontier) && break
            union!(subgraph_nodes, frontier)
        end
    end

    # Build the subgraph induced by the input nodes plus everything that was sampled
    mini_batch_gnn = Graphs.induced_subgraph(loader.graph, collect(subgraph_nodes))

    # Continue with the input node right after this batch
    return mini_batch_gnn, state + batch_size
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Test fixture: a GNNGraph with the four edges 1→2, 2→3, 3→4, 4→5 and random
# Float32 node features.
function create_test_graph()
    edge_sources = [1, 2, 3, 4]
    edge_targets = [2, 3, 4, 5]
    features = rand(Float32, 5, 5)  # 5 features for each of the 5 nodes

    return GNNGraph(edge_sources, edge_targets; ndata = features)
end
|
||
# Tests for the NeighborLoader structure and its iteration protocol.
# Neighbor sampling follows *incoming* edges, so the expected node counts below are
# derived from each input node's in-neighbors.
@testset "NeighborLoader tests" begin

    # 1. Basic functionality: check neighbor sampling and subgraph creation
    @testset "Basic functionality" begin
        g = create_test_graph()

        # 2 neighbors per layer, 2 layers, batch size 2
        loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2],
                                num_layers=2, batch_size=2)

        mini_batch_gnn, next_state = iterate(loader)

        # The mini-batch graph must not be empty
        @test !isempty(mini_batch_gnn.graph)

        # Node 1 has no incoming edge and node 2's only in-neighbor is node 1,
        # so the subgraph consists of exactly the two input nodes
        @test mini_batch_gnn.num_nodes == 2

        # The edge 1→2 lies inside the subgraph
        @test mini_batch_gnn.num_edges > 0

        # Both input nodes were consumed by the first batch, so iteration is over
        @test next_state == 3
        @test iterate(loader, next_state) === nothing
    end

    # 2. Edge case: single node with no neighbors
    @testset "Single node with no neighbors" begin
        g = SimpleDiGraph(1)  # A graph with a single node and no edges
        node_features = rand(Float32, 5, 1)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1)

        mini_batch_gnn, next_state = iterate(loader)

        # The mini-batch graph contains only the input node
        @test size(mini_batch_gnn.x, 2) == 1
    end

    # 3. Edge case: isolated nodes (no edges at all, hence nothing to sample)
    @testset "Isolated nodes" begin
        g = SimpleDiGraph(2)  # Graph with 2 nodes, no edges
        node_features = rand(Float32, 5, 2)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1)

        mini_batch_gnn, next_state = iterate(loader)

        # Only the two input nodes are present, as no neighbors can be sampled
        @test size(mini_batch_gnn.x, 2) == 2
    end

    # 4. Edge case: directed 3-cycle, every node reachable within two hops
    @testset "Directed cycle graph" begin
        g = SimpleDiGraph(3)
        add_edge!(g, 1, 2)
        add_edge!(g, 2, 3)
        add_edge!(g, 3, 1)
        node_features = rand(Float32, 5, 3)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2)

        mini_batch_gnn, next_state = iterate(loader)

        # Walking incoming edges from node 1 reaches node 3, then node 2
        @test size(mini_batch_gnn.x, 2) == 3  # All nodes should be included
    end

    # 5. Edge case: more layers requested than the neighborhood can provide
    @testset "More layers than available neighbors" begin
        g = SimpleDiGraph(3)
        add_edge!(g, 1, 2)
        add_edge!(g, 2, 3)
        node_features = rand(Float32, 5, 3)
        graph = GNNGraph(g, ndata = node_features)

        # 3 layers requested, but node 1 has no incoming edges to expand along
        loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1],
                                num_layers=3)

        mini_batch_gnn, next_state = iterate(loader)

        # Expansion stops immediately, leaving only the input node
        @test size(mini_batch_gnn.x, 2) == 1
    end

    # 6. Edge case: batch size greater than the number of input nodes
    @testset "Large batch size" begin
        g = create_test_graph()

        loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2],
                                num_layers=1, batch_size=10)

        mini_batch_gnn, next_state = iterate(loader)

        # The mini-batch graph must not be empty
        @test !isempty(mini_batch_gnn.graph)

        # Nodes [1, 2] are expected
        @test size(mini_batch_gnn.x, 2) == length(unique([1, 2]))

        # The oversized batch is clamped to the two available input nodes
        @test next_state == 3
        @test iterate(loader, next_state) === nothing
    end

    # 7. Edge case: no neighbors sampled (num_neighbors = [0]) with 1 layer
    @testset "No neighbors sampled" begin
        g = create_test_graph()

        # 0 neighbors per layer, 1 layer, batch size 2
        loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2],
                                num_layers=1, batch_size=2)

        mini_batch_gnn, next_state = iterate(loader)

        # No neighbors are sampled, so only nodes 1 and 2 are in the graph
        @test size(mini_batch_gnn.x, 2) == 2
    end

end