diff --git a/src/general/semidiscretization.jl b/src/general/semidiscretization.jl index 0ed168cc3..ef3f455ca 100644 --- a/src/general/semidiscretization.jl +++ b/src/general/semidiscretization.jl @@ -101,6 +101,21 @@ function create_neighborhood_search(system, neighbor, ::Val{GridNeighborhoodSear return search end +function create_neighborhood_search(system, neighbor, ::Val{NeighborListNeighborhoodSearch}, + min_corner, max_corner) + radius = compact_support(system, neighbor) + grid_nhs = GridNeighborhoodSearch{ndims(system)}(radius, nparticles(neighbor), + min_corner=min_corner, + max_corner=max_corner) + + search = NeighborListNeighborhoodSearch(grid_nhs, nparticles(system)) + + # Initialize neighborhood search + initialize!(search, initial_coordinates(system), initial_coordinates(neighbor)) + + return search +end + @inline function compact_support(system, neighbor) (; smoothing_kernel, smoothing_length) = system return compact_support(smoothing_kernel, smoothing_length) @@ -343,10 +358,12 @@ function update_nhs(u_ode, semi) # Update NHS for each pair of systems foreach_enumerate(systems) do (system_index, system) foreach_enumerate(systems) do (neighbor_index, neighbor) + u = wrap_u(u_ode, system_index, system, semi) u_neighbor = wrap_u(u_ode, neighbor_index, neighbor, semi) neighborhood_search = neighborhood_searches[system_index][neighbor_index] - update!(neighborhood_search, nhs_coords(system, neighbor, u_neighbor)) + update!(neighborhood_search, current_coordinates(u, system), + nhs_coords(system, neighbor, u_neighbor)) end end end diff --git a/src/neighborhood_search/grid_nhs.jl b/src/neighborhood_search/grid_nhs.jl index 81fc280fc..2b9825eae 100644 --- a/src/neighborhood_search/grid_nhs.jl +++ b/src/neighborhood_search/grid_nhs.jl @@ -134,6 +134,10 @@ function update!(neighborhood_search::GridNeighborhoodSearch{NDIMS}, update!(neighborhood_search, i -> extract_svector(x, Val(NDIMS), i)) end +function update!(neighborhood_search::GridNeighborhoodSearch, x, y) + update!(neighborhood_search, y) +end + # Modify the existing hash table by moving particles into their new cells function update!(neighborhood_search::GridNeighborhoodSearch, coords_fun) (; hashtable, cell_buffer, cell_buffer_indices) = neighborhood_search diff --git a/src/neighborhood_search/neighbor_list_nhs.jl b/src/neighborhood_search/neighbor_list_nhs.jl new file mode 100644 index 000000000..82b772d47 --- /dev/null +++ b/src/neighborhood_search/neighbor_list_nhs.jl @@ -0,0 +1,100 @@ +struct NeighborListNeighborhoodSearch{ELTYPE, NHS, PB} + search_radius :: ELTYPE + periodic_box :: PB + grid_nhs :: NHS + neighbor_lists :: Vector{Int} + neighbor_list_start :: Vector{Int} + + function NeighborListNeighborhoodSearch(grid_nhs, n_particles) + (; search_radius, periodic_box) = grid_nhs + + neighbor_lists = Int[] + neighbor_list_start = zeros(Int, n_particles + 1) + + new{typeof(search_radius), + typeof(grid_nhs), typeof(periodic_box)}(search_radius, periodic_box, + grid_nhs, neighbor_lists, + neighbor_list_start) + end +end + +@inline function Base.ndims(neighborhood_search::NeighborListNeighborhoodSearch) + return ndims(neighborhood_search.grid_nhs) +end + +function initialize!(search::NeighborListNeighborhoodSearch, coords, neighbor_coords) + initialize!(search.grid_nhs, neighbor_coords) + + build_neighbor_lists!(search, coords, neighbor_coords) +end + +function update!(search::NeighborListNeighborhoodSearch, coords, neighbor_coords) + update!(search.grid_nhs, neighbor_coords) + + build_neighbor_lists!(search, coords, neighbor_coords) +end + +@inline function eachneighbor(particle, search::NeighborListNeighborhoodSearch) + (; neighbor_lists, neighbor_list_start) = search + + return (neighbor_lists[i] for i in neighbor_list_start[particle]:(neighbor_list_start[particle + 1] - 1)) +end + +function build_neighbor_lists!(search::NeighborListNeighborhoodSearch, coords, + neighbor_coords) + (; grid_nhs, neighbor_lists, neighbor_list_start, search_radius, periodic_box) = search + + resize!(neighbor_lists, 0) + + for particle in 1:(length(neighbor_list_start) - 1) + neighbor_list_start[particle] = length(neighbor_lists) + 1 + + particle_coords = extract_svector(coords, Val(ndims(search)), particle) + + for neighbor in eachneighbor(particle_coords, grid_nhs) + # neighbor_particle_coords = extract_svector(neighbor_coords, + # Val(ndims(search)), neighbor) + + # pos_diff = particle_coords - neighbor_particle_coords + # distance2 = dot(pos_diff, pos_diff) + + # pos_diff, distance2 = compute_periodic_distance(pos_diff, distance2, + # search_radius, periodic_box) + + # if distance2 <= search_radius^2 + append!(neighbor_lists, neighbor) + # end + end + end + + neighbor_list_start[end] = length(neighbor_lists) + 1 + + return search +end + +@inline function for_particle_neighbor_inner(f, system_coords, neighbor_system_coords, + neighborhood_search::NeighborListNeighborhoodSearch, + particle) + (; search_radius, periodic_box) = neighborhood_search + + for neighbor in eachneighbor(particle, neighborhood_search) + particle_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), + particle) + neighbor_coords = extract_svector(neighbor_system_coords, + Val(ndims(neighborhood_search)), neighbor) + + pos_diff = particle_coords - neighbor_coords + distance2 = dot(pos_diff, pos_diff) + + pos_diff, distance2 = compute_periodic_distance(pos_diff, distance2, search_radius, + periodic_box) + + if distance2 <= search_radius^2 + distance = sqrt(distance2) + + # Inline to avoid loss of performance + # compared to not using `for_particle_neighbor`. + @inline f(particle, neighbor, pos_diff, distance) + end + end +end diff --git a/src/neighborhood_search/neighborhood_search.jl b/src/neighborhood_search/neighborhood_search.jl index 0b378454e..623e6dff9 100644 --- a/src/neighborhood_search/neighborhood_search.jl +++ b/src/neighborhood_search/neighborhood_search.jl @@ -108,3 +108,4 @@ end include("trivial_nhs.jl") include("grid_nhs.jl") +include("neighbor_list_nhs.jl")