Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ArborX neighbor list generation to work with ArborX 2.0 #803

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 112 additions & 13 deletions core/src/Cabana_Experimental_NeighborList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ namespace ArborX
{
//! Neighbor access trait for Cabana slice and/or Kokkos View.
template <typename Positions>
struct AccessTraits<Positions, PrimitivesTag,
struct AccessTraits<Positions,
#if ARBORX_VERSION < 10799
PrimitivesTag,
#endif
std::enable_if_t<Cabana::is_slice<Positions>{} ||
Kokkos::is_view<Positions>{}>>
{
Expand All @@ -120,17 +123,21 @@ struct AccessTraits<Positions, PrimitivesTag,
return Cabana::size( x );
}
//! Get the particle at the index.
static KOKKOS_FUNCTION Point get( Positions const& x, size_type i )
static KOKKOS_FUNCTION auto get( Positions const& x, size_type i )
{
return { static_cast<float>( x( i, 0 ) ),
static_cast<float>( x( i, 1 ) ),
static_cast<float>( x( i, 2 ) ) };
return Point{ static_cast<float>( x( i, 0 ) ),
static_cast<float>( x( i, 1 ) ),
static_cast<float>( x( i, 2 ) ) };
}
};
//! Neighbor access trait.
template <typename Positions>
struct AccessTraits<
Cabana::Experimental::Impl::SubPositionsAndRadius<Positions>, PredicatesTag>
struct AccessTraits<Cabana::Experimental::Impl::SubPositionsAndRadius<Positions>
#if ARBORX_VERSION < 10799
,
PredicatesTag
#endif
>
{
//! Position wrapper with partial range and radius information.
using PositionLike =
Expand All @@ -148,10 +155,15 @@ struct AccessTraits<
static KOKKOS_FUNCTION auto get( PositionLike const& x, size_type i )
{
assert( i < size( x ) );
auto const point =
AccessTraits<typename PositionLike::positions_type,
PrimitivesTag>::get( x.data, x.first + i );
return attach( intersects( Sphere{ point, x.radius } ), (int)i );
auto const point = AccessTraits<typename PositionLike::positions_type
#if ARBORX_VERSION < 10799
,
PrimitivesTag
#endif
>::get( x.data, x.first + i );
return attach(
intersects( Sphere{ point, static_cast<float>( x.radius ) } ),
(int)i );
}
};
} // namespace ArborX
Expand Down Expand Up @@ -186,6 +198,21 @@ struct CollisionFilter<HalfNeighborTag>
template <typename Tag>
struct NeighborDiscriminatorCallback
{
#if ARBORX_VERSION >= 10799
template <typename Predicate, typename Geometry, typename OutputFunctor>
KOKKOS_FUNCTION void
operator()( Predicate const& predicate,
ArborX::PairValueIndex<Geometry, int> const& value_pair,
OutputFunctor const& out ) const
{
int const primitive_index = value_pair.index;
int const predicate_index = getData( predicate );
if ( CollisionFilter<Tag>::keep( predicate_index, primitive_index ) )
{
out( primitive_index );
}
}
#else
template <typename Predicate, typename OutputFunctor>
KOKKOS_FUNCTION void operator()( Predicate const& predicate,
int primitive_index,
Expand All @@ -197,13 +224,28 @@ struct NeighborDiscriminatorCallback
out( primitive_index );
}
}
#endif
};

// Count in the first pass
template <typename Counts, typename Tag>
struct NeighborDiscriminatorCallback2D_FirstPass
{
Counts counts;
#if ARBORX_VERSION >= 10799
template <typename Predicate, typename Geometry>
KOKKOS_FUNCTION void
operator()( Predicate const& predicate,
ArborX::PairValueIndex<Geometry, int> const& value_pair ) const
{
int const primitive_index = value_pair.index;
int const predicate_index = getData( predicate );
if ( CollisionFilter<Tag>::keep( predicate_index, primitive_index ) )
{
++counts( predicate_index ); // WARNING see below**
}
}
#else
template <typename Predicate>
KOKKOS_FUNCTION void operator()( Predicate const& predicate,
int primitive_index ) const
Expand All @@ -214,6 +256,7 @@ struct NeighborDiscriminatorCallback2D_FirstPass
++counts( predicate_index ); // WARNING see below**
}
}
#endif
};

// Preallocate and attempt fill in the first pass
Expand All @@ -222,6 +265,29 @@ struct NeighborDiscriminatorCallback2D_FirstPass_BufferOptimization
{
Counts counts;
Neighbors neighbors;
#if ARBORX_VERSION >= 10799
template <typename Predicate, typename Geometry>
KOKKOS_FUNCTION void
operator()( Predicate const& predicate,
ArborX::PairValueIndex<Geometry, int> const& value_pair ) const
{
int const primitive_index = value_pair.index;
int const predicate_index = getData( predicate );
auto& count = counts( predicate_index );
if ( CollisionFilter<Tag>::keep( predicate_index, primitive_index ) )
{
if ( count < (int)neighbors.extent( 1 ) )
{
neighbors( predicate_index, count++ ) =
primitive_index; // WARNING see below**
}
else
{
count++;
}
}
}
#else
template <typename Predicate>
KOKKOS_FUNCTION void operator()( Predicate const& predicate,
int primitive_index ) const
Expand All @@ -241,6 +307,7 @@ struct NeighborDiscriminatorCallback2D_FirstPass_BufferOptimization
}
}
}
#endif
};

// Fill in the second pass
Expand All @@ -249,6 +316,23 @@ struct NeighborDiscriminatorCallback2D_SecondPass
{
Counts counts;
Neighbors neighbors;
#if ARBORX_VERSION >= 10799
template <typename Predicate, typename Geometry>
KOKKOS_FUNCTION void
operator()( Predicate const& predicate,
ArborX::PairValueIndex<Geometry, int> const& value_pair ) const
{
int const primitive_index = value_pair.index;
int const predicate_index = getData( predicate );
auto& count = counts( predicate_index );
if ( CollisionFilter<Tag>::keep( predicate_index, primitive_index ) )
{
assert( count < (int)neighbors.extent( 1 ) );
neighbors( predicate_index, count++ ) =
primitive_index; // WARNING see below**
}
}
#else
template <typename Predicate>
KOKKOS_FUNCTION void operator()( Predicate const& predicate,
int primitive_index ) const
Expand All @@ -262,6 +346,7 @@ struct NeighborDiscriminatorCallback2D_SecondPass
primitive_index; // WARNING see below**
}
}
#endif
};

// NOTE** Taking advantage of the knowledge that one predicate is processed by a
Expand Down Expand Up @@ -319,7 +404,12 @@ auto makeNeighborList( ExecutionSpace space, Tag, Positions const& positions,

using memory_space = typename Positions::memory_space;

#if ARBORX_VERSION >= 10799
ArborX::BoundingVolumeHierarchy bvh(
space, ArborX::Experimental::attach_indices<int>( positions ) );
#else
ArborX::BVH<memory_space> bvh( space, positions );
#endif

Kokkos::View<int*, memory_space> indices(
Kokkos::view_alloc( "indices", Kokkos::WithoutInitializing ), 0 );
Expand Down Expand Up @@ -444,14 +534,23 @@ auto make2DNeighborList( ExecutionSpace space, Tag, Positions const& positions,

using memory_space = typename Positions::memory_space;

#if ARBORX_VERSION >= 10799
ArborX::BoundingVolumeHierarchy bvh(
space, ArborX::Experimental::attach_indices<int>( positions ) );
#else
ArborX::BVH<memory_space> bvh( space, positions );
#endif

auto const predicates =
Impl::makePredicates( positions, first, last, radius );

auto const n_queries =
ArborX::AccessTraits<std::remove_const_t<decltype( predicates )>,
ArborX::PredicatesTag>::size( predicates );
ArborX::AccessTraits<std::remove_const_t<decltype( predicates )>
#if ARBORX_VERSION < 10799
,
ArborX::PredicatesTag
#endif
>::size( predicates );

Kokkos::View<int**, memory_space> neighbors;
Kokkos::View<int*, memory_space> counts( "counts", n_queries );
Expand Down
Loading