Skip to content

Commit

Permalink
Add nth nearest neighbors and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed May 14, 2024
1 parent 0feb539 commit ddd13b6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/lib/GraphsExtensions/src/neighbors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Graphs: AbstractGraph, neighborhood

function nth_nearest_neighbors(g::AbstractGraph, v, n::Int)
isone(n) && return neighborhood(g, v, 1)
return setdiff(neighborhood(g, v, n), neighborhood(g, v, n - 1))
end

next_nearest_neighbors(g::AbstractGraph, v) = nth_nearest_neighbors(g, v, 2)
23 changes: 23 additions & 0 deletions test/test_graphsextensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
@eval module $(gensym())
using NamedGraphs.NamedGraphGenerators: named_grid
using NamedGraphs.GraphsExtensions: next_nearest_neighbors, nth_nearest_neighbors
using Test: @test, @testset

#TODO: Add tests for other graphs extensions
@testset "GraphsExtensions" begin
@testset "Test nth nearest neighbours" begin
L = 10
g = named_grid((L, 1))
vstart = (1, 1)
@test only(nth_nearest_neighbors(g, vstart, L - 1)) == (L, 1)
@test only(next_nearest_neighbors(g, vstart)) == (3, 1)

L = 9
g = named_grid((L, L))
v_middle = (ceil(Int64, L / 2), ceil(Int64, L / 2))
corners = [(L, 1), (1, L), (L, L), (1, 1)]
@test length(next_nearest_neighbors(g, v_middle)) == 8
@test issetequal(nth_nearest_neighbors(g, v_middle, L - 1), corners)
end
end
end

0 comments on commit ddd13b6

Please sign in to comment.