Skip to content

Commit

Permalink
feat: add NeighborLoader (#497)
Browse files Browse the repository at this point in the history
  • Loading branch information
askorupka authored Nov 3, 2024
1 parent 9d9e8d0 commit 6b8a7fb
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 2 deletions.
5 changes: 3 additions & 2 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})

node_map = Dict(node => i for (i, node) in enumerate(nodes))

edge_list = [collect(t) for t in zip(edge_index(graph)[1],edge_index(graph)[2])]

# Collect edges to add
source = Int[]
target = Int[]
Expand All @@ -187,8 +189,7 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
if neighbor in keys(node_map)
push!(target, node_map[node])
push!(source, node_map[neighbor])

eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
eindex = findfirst(x -> x == [neighbor, node], edge_list)
push!(eindices, eindex)
end
end
Expand Down
2 changes: 2 additions & 0 deletions GraphNeuralNetworks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -28,6 +29,7 @@ CUDA = "4, 5"
ChainRulesCore = "1"
Flux = "0.14"
Functors = "0.4.1"
Graphs = "1.12"
GNNGraphs = "1.0"
GNNlib = "0.2"
LinearAlgebra = "1"
Expand Down
1 change: 1 addition & 0 deletions GraphNeuralNetworks/docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ makedocs(;
"Message Passing" => "api/messagepassing.md",
"Heterogeneous Graphs" => "api/heterograph.md",
"Temporal Graphs" => "api/temporalgraph.md",
"Samplers" => "api/samplers.md",
"Utils" => "api/utils.md",
],
"Developer Notes" => "dev.md",
Expand Down
14 changes: 14 additions & 0 deletions GraphNeuralNetworks/docs/src/api/samplers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
```@meta
CurrentModule = GraphNeuralNetworks
```

# Samplers


## Docs

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["samplers.jl"]
Private = false
```
4 changes: 4 additions & 0 deletions GraphNeuralNetworks/src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using NNlib: scatter, gather
using ChainRulesCore
using Reexport
using MLUtils: zeros_like
using Graphs: Graphs

using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
Expand Down Expand Up @@ -66,4 +67,7 @@ export GlobalPool,

include("deprecations.jl")

include("samplers.jl")
export NeighborLoader

end
103 changes: 103 additions & 0 deletions GraphNeuralNetworks/src/samplers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
    NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size])

A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs).
It supports multi-layer sampling of neighbors for a batch of input nodes, which is useful for
mini-batch training as originally introduced in the paper
["Inductive Representation Learning on Large Graphs"](https://arxiv.org/abs/1706.02216).

Iterating over a `NeighborLoader` yields one mini-batch graph per batch of `input_nodes`:
the subgraph of `graph` induced by the batch nodes together with the neighbors sampled
around them.

# Fields

- `graph::GNNGraph`: The input graph.
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer.
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling.
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors).
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`.
- `neighbors_cache::Dict{Int, Vector{Int}}`: Internal cache of the neighbor lists that have already been looked up.

# Usage

```julia
julia> loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2)

julia> batch_counter = 0

julia> for mini_batch_gnn in loader
           batch_counter += 1
           println("Batch ", batch_counter, ": Nodes in mini-batch graph: ", nv(mini_batch_gnn))
       end
```
"""
struct NeighborLoader
graph::GNNGraph # Full input graph that the mini-batch subgraphs are extracted from
num_neighbors::Vector{Int} # Upper bound on the neighbors sampled per node; indexed by layer
input_nodes::Vector{Int} # Seed nodes; iteration walks through them in chunks of `batch_size`
num_layers::Int # Number of neighborhood-expansion rounds performed for each seed node
batch_size::Union{Int, Nothing} # Seed nodes per mini-batch; the keyword constructor stores length(input_nodes) when it is not given
neighbors_cache::Dict{Int, Vector{Int}} # node => its incoming-neighbor list, filled lazily by `get_neighbors`
end

# Keyword constructor for `NeighborLoader`.
#
# - `num_neighbors`: number of neighbors to sample per node, one entry per layer.
# - `input_nodes`: seed nodes for the sampling; when omitted (`nothing`) every node
#   `1:graph.num_nodes` of `graph` is used.
# - `num_layers`: number of neighborhood-expansion rounds.
# - `batch_size`: seed nodes per mini-batch; when omitted (`nothing`) all input nodes
#   form a single batch.
#
# Throws an `ArgumentError` if an explicit `batch_size` is not positive, because the
# iteration state would then never advance.
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int},
                        input_nodes::Union{Vector{Int}, Nothing}=nothing,
                        num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
    if batch_size !== nothing && batch_size < 1
        throw(ArgumentError("batch_size must be a positive integer, got $batch_size"))
    end

    # Resolve the defaults first so that the batch size is derived from the nodes
    # that are actually used, not from a possibly-`nothing` keyword value.
    nodes = input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes
    resolved_batch_size = batch_size === nothing ? length(nodes) : batch_size

    return NeighborLoader(graph, num_neighbors, nodes, num_layers, resolved_batch_size,
                          Dict{Int, Vector{Int}}())
end

# Return the incoming neighbors of `node` in `loader.graph`.
# The list is looked up in `loader.neighbors_cache`; on a miss it is computed with
# `Graphs.neighbors(...; dir = :in)` and stored so later calls reuse it.
function get_neighbors(loader::NeighborLoader, node::Int)
    cache = loader.neighbors_cache
    return get!(cache, node) do
        Graphs.neighbors(loader.graph, node, dir = :in)
    end
end

# Sample neighbors of `node` for the given `layer`.
#
# Draws `min(loader.num_neighbors[layer], number of neighbors)` entries from the
# node's neighbor list *without replacement*, so no position of the list is picked
# twice and all neighbors are returned when fewer than the requested number exist.
# Returns an empty `Int[]` when the node has no neighbors.
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
    neighbors = get_neighbors(loader, node)
    isempty(neighbors) && return Int[]

    # Limit to required samples for this layer
    num_samples = min(loader.num_neighbors[layer], length(neighbors))

    # Work on a copy: `neighbors` is the vector stored in the loader's cache and
    # must not be permuted in place.
    pool = copy(neighbors)
    n = length(pool)

    # Partial Fisher-Yates shuffle: after step `i` the first `i` slots hold a
    # uniformly random selection without repetition.
    for i in 1:num_samples
        j = rand(i:n)
        pool[i], pool[j] = pool[j], pool[i]
    end

    return pool[1:num_samples]
end

# Iterator protocol for NeighborLoader with lazy batch loading.
#
# `state` is the index (into `loader.input_nodes`) of the first seed node of the next
# batch. Each step takes up to `batch_size` seed nodes, expands every one of them for
# `num_layers` rounds of neighbor sampling, and returns the subgraph induced by all
# collected nodes together with the state of the following batch. Returns `nothing`
# once all input nodes have been consumed.
function Base.iterate(loader::NeighborLoader, state=1)
    num_inputs = length(loader.input_nodes)
    state > num_inputs && return nothing  # every input node has been served

    # The field type allows `nothing`; treat it as "one batch with all input nodes".
    max_batch = something(loader.batch_size, num_inputs)
    # The last batch may hold fewer nodes than a full batch.
    batch_size = min(max_batch, num_inputs - state + 1)
    batch_nodes = loader.input_nodes[state:(state + batch_size - 1)]

    # Nodes of the mini-batch subgraph; the seed nodes are always part of it.
    subgraph_nodes = Set(batch_nodes)

    for node in batch_nodes
        # Nodes reached in the previous round (starting with the seed node itself)
        frontier = Set([node])

        for layer in 1:loader.num_layers
            next_frontier = Set{Int}()
            for n in frontier
                # In-place union: avoids re-allocating the set for every node
                union!(next_frontier, sample_nbrs(loader, n, layer))
            end
            frontier = next_frontier
            union!(subgraph_nodes, frontier)
        end
    end

    subgraph_node_list = collect(subgraph_nodes)

    # Only possible when the batch itself is empty (e.g. a non-positive batch size).
    if isempty(subgraph_node_list)
        return GNNGraph(), state + batch_size
    end

    mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list)

    # Continue iteration for the next batch
    return mini_batch_gnn, state + batch_size
end

# Number of mini-batches produced by one full pass over the loader.
# Needed because the default `Base.IteratorSize` is `HasLength()`, so `collect(loader)`
# and similar generic code call `length`.
function Base.length(loader::NeighborLoader)
    num_inputs = length(loader.input_nodes)
    num_inputs == 0 && return 0
    return cld(num_inputs, something(loader.batch_size, num_inputs))
end
1 change: 1 addition & 0 deletions GraphNeuralNetworks/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ tests = [
"layers/temporalconv",
"layers/pool",
"examples/node_classification_cora",
"samplers"
]

!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
Expand Down
125 changes: 125 additions & 0 deletions GraphNeuralNetworks/test/samplers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Build the graph shared by several tests: a directed path 1 -> 2 -> 3 -> 4 -> 5
# stored as a GNNGraph whose node data is a random 5x5 Float32 matrix
# (one column of 5 features per node).
function create_test_graph()
    num_nodes = 5
    src = collect(1:(num_nodes - 1))  # edge sources: 1, 2, 3, 4
    dst = collect(2:num_nodes)        # edge targets: 2, 3, 4, 5
    x = rand(Float32, 5, num_nodes)

    return GNNGraph(src, dst, ndata = x)
end

# Tests for the NeighborLoader structure and its iteration.
# Neighbors are looked up along incoming edges (`dir = :in`), so a node only pulls in
# the nodes that point to it; the expected node counts below follow from that.
@testset "NeighborLoader tests" begin

    # 1. Basic functionality: Check neighbor sampling and subgraph creation
    @testset "Basic functionality" begin
        g = create_test_graph()

        # NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2
        loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2],
                                num_layers=2, batch_size=2)

        mini_batch_gnn, next_state = iterate(loader)

        # The mini-batch graph is not empty
        @test !isempty(mini_batch_gnn.graph)

        # Node 1 has no incoming edge and node 2 is only reached from node 1,
        # so the mini-batch consists of exactly the two input nodes.
        @test mini_batch_gnn.num_nodes == 2

        # The edge 1 -> 2 is part of the subgraph
        @test mini_batch_gnn.num_edges > 0

        # Both input nodes were consumed by the first batch
        @test iterate(loader, next_state) === nothing
    end

    # 2. Edge case: Single node with no neighbors
    @testset "Single node with no neighbors" begin
        g = SimpleDiGraph(1)  # A graph with a single node and no edges
        node_features = rand(Float32, 5, 1)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1)

        mini_batch_gnn, next_state = iterate(loader)

        # The mini-batch graph contains only the input node
        @test size(mini_batch_gnn.x, 2) == 1
    end

    # 3. Edge case: Isolated nodes (no edges at all)
    @testset "Node with no outgoing edges" begin
        g = SimpleDiGraph(2)  # Graph with 2 nodes, no edges
        node_features = rand(Float32, 5, 2)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1)

        mini_batch_gnn, next_state = iterate(loader)

        # Only the two isolated input nodes, as no neighbors can be sampled
        @test size(mini_batch_gnn.x, 2) == 2
    end

    # 4. Edge case: Directed cycle 1 -> 2 -> 3 -> 1
    @testset "Fully connected graph" begin
        g = SimpleDiGraph(3)
        add_edge!(g, 1, 2)
        add_edge!(g, 2, 3)
        add_edge!(g, 3, 1)
        node_features = rand(Float32, 5, 3)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2)

        mini_batch_gnn, next_state = iterate(loader)

        # Two rounds starting from node 1 reach node 3 and then node 2,
        # so all nodes are included.
        @test size(mini_batch_gnn.x, 2) == 3
    end

    # 5. Edge case: More layers than the neighborhood can be expanded
    @testset "More layers than available neighbors" begin
        g = SimpleDiGraph(3)
        add_edge!(g, 1, 2)
        add_edge!(g, 2, 3)
        node_features = rand(Float32, 5, 3)
        graph = GNNGraph(g, ndata = node_features)

        # 3 layers requested, but the input node has no incoming edges at all
        loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3)

        mini_batch_gnn, next_state = iterate(loader)

        # Nothing points to node 1, so the mini-batch contains only the input node
        @test size(mini_batch_gnn.x, 2) == 1
    end

    # 6. Edge case: Large batch size greater than the number of input nodes
    @testset "Large batch size" begin
        g = create_test_graph()

        # NeighborLoader with a larger batch size than input nodes
        loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2],
                                num_layers=1, batch_size=10)

        mini_batch_gnn, next_state = iterate(loader)

        # The mini-batch graph is not empty
        @test !isempty(mini_batch_gnn.graph)

        # The batch is clamped to the two available input nodes [1, 2]
        @test size(mini_batch_gnn.x, 2) == 2

        # The single (clamped) batch exhausts the loader
        @test iterate(loader, next_state) === nothing
    end

    # 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer
    @testset "No neighbors sampled" begin
        g = create_test_graph()

        # NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2
        loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2],
                                num_layers=1, batch_size=2)

        mini_batch_gnn, next_state = iterate(loader)

        # No neighbors are sampled, so only the input nodes 1 and 2 are in the graph
        @test size(mini_batch_gnn.x, 2) == 2
    end

end

0 comments on commit 6b8a7fb

Please sign in to comment.