diff --git a/GraphNeuralNetworks/src/samplers.jl b/GraphNeuralNetworks/src/samplers.jl index 44fb3c825..4157b52fc 100644 --- a/GraphNeuralNetworks/src/samplers.jl +++ b/GraphNeuralNetworks/src/samplers.jl @@ -1,7 +1,25 @@ using GraphNeuralNetworks using Graphs -# Define a NeighborLoader structure for sampling neighbors +""" + struct NeighborLoader + +A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs). +It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training. + +# Fields: +- `graph::GNNGraph`: The input graph containing nodes and edges, along with node features (from GraphNeuralNetworks.jl). +- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer. +- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling. +- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors). +- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`. +- `neighbors_cache::Dict{Int, Vector{Int}}`: A cache of each node's neighbor list, so that neighbors are not recomputed on later lookups.
+ +# Usage: +```julia +loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) +``` +""" struct NeighborLoader graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl) num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer @@ -11,11 +29,40 @@ struct NeighborLoader neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation end -# Constructor for NeighborLoader with optional batch size +### `NeighborLoader` constructor +""" + NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing) + +Creates a `NeighborLoader` to sample neighbors from the provided `graph` for training. + The loader supports batching and multi-layer neighbor sampling. + +# Arguments: +- `graph::GNNGraph`: The input graph with node features. +- `num_neighbors::Vector{Int}`: Number of neighbors to sample per node, per layer. +- `input_nodes::Vector{Int}`: Set of starting nodes for sampling. +- `num_layers::Int`: Number of layers to expand neighborhoods for sampling. +- `batch_size::Union{Int, Nothing}`: Optional batch size. If `nothing`, it defaults to the length of `input_nodes`. + +# Returns: +A `NeighborLoader` object. +""" function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing) return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}()) end +""" + get_neighbors(loader::NeighborLoader, node::Int) -> Vector{Int} + +Returns the neighbors of a given `node` in the graph from the `NeighborLoader`. + It first checks if the neighbors are cached; if not, it retrieves the neighbors from the graph and caches them for future use. + +# Arguments: +- `loader::NeighborLoader`: The `NeighborLoader` instance.
+- `node::Int`: The node whose neighbors you want to sample. + +# Returns: +A vector of neighbor node indices. +""" # Function to get cached neighbors or compute them function get_neighbors(loader::NeighborLoader, node::Int) if haskey(loader.neighbors_cache, node) @@ -27,6 +74,20 @@ function get_neighbors(loader::NeighborLoader, node::Int) end end +""" + sample_neighbors(loader::NeighborLoader, node::Int, layer::Int) -> Vector{Int} + +Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN. + The number of neighbors sampled is defined in `loader.num_neighbors`. + +# Arguments: +- `loader::NeighborLoader`: The `NeighborLoader` instance. +- `node::Int`: The node to sample neighbors for. +- `layer::Int`: The current GNN layer (used to determine how many neighbors to sample). + +# Returns: +A vector of sampled neighbor node indices. +""" # Function to sample neighbors for a given node at a specific layer function sample_neighbors(loader::NeighborLoader, node::Int, layer::Int) neighbors = get_neighbors(loader, node) @@ -38,6 +99,20 @@ function sample_neighbors(loader::NeighborLoader, node::Int, layer::Int) end end +""" + induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph + +Generates a subgraph from the original graph using the provided `nodes`. + The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph. + If a node has no neighbors, an isolated node will be added to the subgraph. + +# Arguments: +- `graph::GNNGraph`: The original graph containing nodes, edges, and node features. +- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph. + +# Returns: +A new `GNNGraph` containing the subgraph with the specified nodes and their features. 
+""" function induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) if isempty(nodes) return GNNGraph() # Return empty graph if no nodes are provided @@ -73,6 +148,22 @@ function induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) return GNNGraph(source, target, ndata = new_features) # Return the new GNNGraph with subgraph and features end +""" + Base.iterate(loader::NeighborLoader, state::Int=1) -> Tuple{GNNGraph, Int} + +Implements the iterator protocol for `NeighborLoader`, allowing mini-batch processing for neighbor sampling in GNNs. + Each call to `iterate` returns a mini-batch subgraph with sampled neighbors for a batch of input nodes, + expanding their neighborhoods for a specified number of layers. + +# Arguments: +- `loader::NeighborLoader`: The `NeighborLoader` instance to sample neighbors from. +- `state::Int`: The current position in the input nodes for batching. Defaults to 1. + +# Returns: +A tuple `(mini_batch_gnn, next_state)` where: +- `mini_batch_gnn::GNNGraph`: The subgraph induced by the sampled nodes and their neighbors for the current mini-batch. +- `next_state::Int`: The next state (index) for iterating through the input nodes. Once the input nodes are exhausted, `iterate` returns `nothing` instead of a tuple. +""" # Iterator protocol for NeighborLoader with lazy batch loading function Base.iterate(loader::NeighborLoader, state=1) if state > length(loader.input_nodes) @@ -113,4 +204,4 @@ function Base.iterate(loader::NeighborLoader, state=1) # Continue iteration for the next batch return mini_batch_gnn, state + batch_size -end \ No newline at end of file +end