diff --git a/GraphNeuralNetworks/src/samplers.jl b/GraphNeuralNetworks/src/samplers.jl index 2d65dc217..71383e539 100644 --- a/GraphNeuralNetworks/src/samplers.jl +++ b/GraphNeuralNetworks/src/samplers.jl @@ -72,7 +72,7 @@ function get_neighbors(loader::NeighborLoader, node::Int) end """ - sample_neighbors(loader::NeighborLoader, node::Int, layer::Int) + sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN. The number of neighbors sampled is defined in `loader.num_neighbors`. @@ -86,7 +86,7 @@ Samples a specified number of neighbors for the given `node` at a particular `la A vector of sampled neighbor node indices. """ # Function to sample neighbors for a given node at a specific layer -function sample_neighbors(loader::NeighborLoader, node::Int, layer::Int) +function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) neighbors = get_neighbors(loader, node) if isempty(neighbors) return Int[] @@ -133,7 +133,7 @@ function Base.iterate(loader::NeighborLoader, state=1) for layer in 1:loader.num_layers new_neighbors = Set{Int}() for n in sampled_neighbors - neighbors = sample_neighbors(loader, n, layer) # Sample neighbors of the node for this layer + neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set end sampled_neighbors = new_neighbors