Skip to content

Commit

Permalink
fix: deduplicate function
Browse files Browse the repository at this point in the history
  • Loading branch information
askorupka committed Oct 12, 2024
1 parent 2d7bd0b commit 65aa564
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions GraphNeuralNetworks/src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function get_neighbors(loader::NeighborLoader, node::Int)
end

"""
sample_neighbors(loader::NeighborLoader, node::Int, layer::Int)
sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN.
The number of neighbors sampled is defined in `loader.num_neighbors`.
Expand All @@ -86,7 +86,7 @@ Samples a specified number of neighbors for the given `node` at a particular `la
A vector of sampled neighbor node indices.
"""
# Function to sample neighbors for a given node at a specific layer
function sample_neighbors(loader::NeighborLoader, node::Int, layer::Int)
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
neighbors = get_neighbors(loader, node)
if isempty(neighbors)
return Int[]
Expand Down Expand Up @@ -133,7 +133,7 @@ function Base.iterate(loader::NeighborLoader, state=1)
for layer in 1:loader.num_layers
new_neighbors = Set{Int}()
for n in sampled_neighbors
neighbors = sample_neighbors(loader, n, layer) # Sample neighbors of the node for this layer
neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer
new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set
end
sampled_neighbors = new_neighbors
Expand Down

0 comments on commit 65aa564

Please sign in to comment.