Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of Concept: benchmark neighborhood search overhead #284

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/general/semidiscretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ function create_neighborhood_search(system, neighbor, ::Val{GridNeighborhoodSear
return search
end

function create_neighborhood_search(system, neighbor, ::Val{NeighborListNeighborhoodSearch},
min_corner, max_corner)
radius = compact_support(system, neighbor)
grid_nhs = GridNeighborhoodSearch{ndims(system)}(radius, nparticles(neighbor),
min_corner=min_corner,
max_corner=max_corner)

search = NeighborListNeighborhoodSearch(grid_nhs, nparticles(system))

# Initialize neighborhood search
initialize!(search, initial_coordinates(system), initial_coordinates(neighbor))

return search
end

@inline function compact_support(system, neighbor)
(; smoothing_kernel, smoothing_length) = system
return compact_support(smoothing_kernel, smoothing_length)
Expand Down Expand Up @@ -343,10 +358,12 @@ function update_nhs(u_ode, semi)
# Update NHS for each pair of systems
foreach_enumerate(systems) do (system_index, system)
foreach_enumerate(systems) do (neighbor_index, neighbor)
u = wrap_u(u_ode, system_index, system, semi)
u_neighbor = wrap_u(u_ode, neighbor_index, neighbor, semi)
neighborhood_search = neighborhood_searches[system_index][neighbor_index]

update!(neighborhood_search, nhs_coords(system, neighbor, u_neighbor))
update!(neighborhood_search, current_coordinates(u, system),
nhs_coords(system, neighbor, u_neighbor))
end
end
end
Expand Down
4 changes: 4 additions & 0 deletions src/neighborhood_search/grid_nhs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ function update!(neighborhood_search::GridNeighborhoodSearch{NDIMS},
update!(neighborhood_search, i -> extract_svector(x, Val(NDIMS), i))
end

function update!(neighborhood_search::GridNeighborhoodSearch, x, y)
update!(neighborhood_search, y)
end

# Modify the existing hash table by moving particles into their new cells
function update!(neighborhood_search::GridNeighborhoodSearch, coords_fun)
(; hashtable, cell_buffer, cell_buffer_indices) = neighborhood_search
Expand Down
100 changes: 100 additions & 0 deletions src/neighborhood_search/neighbor_list_nhs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
struct NeighborListNeighborhoodSearch{ELTYPE, NHS, PB}
search_radius :: ELTYPE
periodic_box :: PB
grid_nhs :: NHS
neighbor_lists :: Vector{Int}
neighbor_list_start :: Vector{Int}

function NeighborListNeighborhoodSearch(grid_nhs, n_particles)
(; search_radius, periodic_box) = grid_nhs

neighbor_lists = Int[]
neighbor_list_start = zeros(Int, n_particles + 1)

new{typeof(search_radius),
typeof(grid_nhs), typeof(periodic_box)}(search_radius, periodic_box,
grid_nhs, neighbor_lists,
neighbor_list_start)
end
end

@inline function Base.ndims(neighborhood_search::NeighborListNeighborhoodSearch)
return ndims(neighborhood_search.grid_nhs)
end

function initialize!(search::NeighborListNeighborhoodSearch, coords, neighbor_coords)
initialize!(search.grid_nhs, neighbor_coords)

build_neighbor_lists!(search, coords, neighbor_coords)
end

function update!(search::NeighborListNeighborhoodSearch, coords, neighbor_coords)
update!(search.grid_nhs, neighbor_coords)

build_neighbor_lists!(search, coords, neighbor_coords)
end

@inline function eachneighbor(particle, search::NeighborListNeighborhoodSearch)
(; neighbor_lists, neighbor_list_start) = search

return (neighbor_lists[i] for i in neighbor_list_start[particle]:(neighbor_list_start[particle + 1] - 1))
end

function build_neighbor_lists!(search::NeighborListNeighborhoodSearch, coords,
neighbor_coords)
(; grid_nhs, neighbor_lists, neighbor_list_start, search_radius, periodic_box) = search

resize!(neighbor_lists, 0)

for particle in 1:(length(neighbor_list_start) - 1)
neighbor_list_start[particle] = length(neighbor_lists) + 1

particle_coords = extract_svector(coords, Val(ndims(search)), particle)

for neighbor in eachneighbor(particle_coords, grid_nhs)
# neighbor_particle_coords = extract_svector(neighbor_coords,
# Val(ndims(search)), neighbor)

# pos_diff = particle_coords - neighbor_particle_coords
# distance2 = dot(pos_diff, pos_diff)

# pos_diff, distance2 = compute_periodic_distance(pos_diff, distance2,
# search_radius, periodic_box)

# if distance2 <= search_radius^2
append!(neighbor_lists, neighbor)
# end
end
end

neighbor_list_start[end] = length(neighbor_lists) + 1

return search
end

@inline function for_particle_neighbor_inner(f, system_coords, neighbor_system_coords,
neighborhood_search::NeighborListNeighborhoodSearch,
particle)
(; search_radius, periodic_box) = neighborhood_search

for neighbor in eachneighbor(particle, neighborhood_search)
particle_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)),
particle)
neighbor_coords = extract_svector(neighbor_system_coords,
Val(ndims(neighborhood_search)), neighbor)

pos_diff = particle_coords - neighbor_coords
distance2 = dot(pos_diff, pos_diff)

pos_diff, distance2 = compute_periodic_distance(pos_diff, distance2, search_radius,
periodic_box)

if distance2 <= search_radius^2
distance = sqrt(distance2)

# Inline to avoid loss of performance
# compared to not using `for_particle_neighbor`.
@inline f(particle, neighbor, pos_diff, distance)
end
end
end
1 change: 1 addition & 0 deletions src/neighborhood_search/neighborhood_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,4 @@ end

include("trivial_nhs.jl")
include("grid_nhs.jl")
include("neighbor_list_nhs.jl")