diff --git a/core/src/Cabana_Experimental_NeighborList.hpp b/core/src/Cabana_Experimental_NeighborList.hpp index 25e874b3c..ab0b6e668 100644 --- a/core/src/Cabana_Experimental_NeighborList.hpp +++ b/core/src/Cabana_Experimental_NeighborList.hpp @@ -106,7 +106,10 @@ namespace ArborX { //! Neighbor access trait for Cabana slice and/or Kokkos View. template -struct AccessTraits{} || Kokkos::is_view{}>> { @@ -120,17 +123,21 @@ struct AccessTraits( x( i, 0 ) ), - static_cast( x( i, 1 ) ), - static_cast( x( i, 2 ) ) }; + return Point{ static_cast( x( i, 0 ) ), + static_cast( x( i, 1 ) ), + static_cast( x( i, 2 ) ) }; } }; //! Neighbor access trait. template -struct AccessTraits< - Cabana::Experimental::Impl::SubPositionsAndRadius, PredicatesTag> +struct AccessTraits +#if ARBORX_VERSION < 10799 + , + PredicatesTag +#endif + > { //! Position wrapper with partial range and radius information. using PositionLike = @@ -148,10 +155,15 @@ struct AccessTraits< static KOKKOS_FUNCTION auto get( PositionLike const& x, size_type i ) { assert( i < size( x ) ); - auto const point = - AccessTraits::get( x.data, x.first + i ); - return attach( intersects( Sphere{ point, x.radius } ), (int)i ); + auto const point = AccessTraits::get( x.data, x.first + i ); + return attach( + intersects( Sphere{ point, static_cast( x.radius ) } ), + (int)i ); } }; } // namespace ArborX @@ -186,6 +198,21 @@ struct CollisionFilter template struct NeighborDiscriminatorCallback { +#if ARBORX_VERSION >= 10799 + template + KOKKOS_FUNCTION void + operator()( Predicate const& predicate, + ArborX::PairValueIndex const& value_pair, + OutputFunctor const& out ) const + { + int const primitive_index = value_pair.index; + int const predicate_index = getData( predicate ); + if ( CollisionFilter::keep( predicate_index, primitive_index ) ) + { + out( primitive_index ); + } + } +#else template KOKKOS_FUNCTION void operator()( Predicate const& predicate, int primitive_index, @@ -197,6 +224,7 @@ struct NeighborDiscriminatorCallback out( primitive_index ); } } +#endif }; // Count in the first pass @@ -204,6 +232,20 @@ template struct NeighborDiscriminatorCallback2D_FirstPass { Counts counts; +#if ARBORX_VERSION >= 10799 + template + KOKKOS_FUNCTION void + operator()( Predicate const& predicate, + ArborX::PairValueIndex const& value_pair ) const + { + int const primitive_index = value_pair.index; + int const predicate_index = getData( predicate ); + if ( CollisionFilter::keep( predicate_index, primitive_index ) ) + { + ++counts( predicate_index ); // WARNING see below** + } + } +#else template KOKKOS_FUNCTION void operator()( Predicate const& predicate, int primitive_index ) const @@ -214,6 +256,7 @@ struct NeighborDiscriminatorCallback2D_FirstPass ++counts( predicate_index ); // WARNING see below** } } +#endif }; // Preallocate and attempt fill in the first pass @@ -222,6 +265,29 @@ struct NeighborDiscriminatorCallback2D_FirstPass_BufferOptimization { Counts counts; Neighbors neighbors; +#if ARBORX_VERSION >= 10799 + template + KOKKOS_FUNCTION void + operator()( Predicate const& predicate, + ArborX::PairValueIndex const& value_pair ) const + { + int const primitive_index = value_pair.index; + int const predicate_index = getData( predicate ); + auto& count = counts( predicate_index ); + if ( CollisionFilter::keep( predicate_index, primitive_index ) ) + { + if ( count < (int)neighbors.extent( 1 ) ) + { + neighbors( predicate_index, count++ ) = + primitive_index; // WARNING see below** + } + else + { + count++; + } + } + } +#else template KOKKOS_FUNCTION void operator()( Predicate const& predicate, int primitive_index ) const @@ -241,6 +307,7 @@ struct NeighborDiscriminatorCallback2D_FirstPass_BufferOptimization } } } +#endif }; // Fill in the second pass @@ -249,6 +316,23 @@ struct NeighborDiscriminatorCallback2D_SecondPass { Counts counts; Neighbors neighbors; +#if ARBORX_VERSION >= 10799 + template + KOKKOS_FUNCTION void + operator()( Predicate const& predicate, + ArborX::PairValueIndex const& value_pair ) const + { + int const primitive_index = value_pair.index; + int const predicate_index = getData( predicate ); + auto& count = counts( predicate_index ); + if ( CollisionFilter::keep( predicate_index, primitive_index ) ) + { + assert( count < (int)neighbors.extent( 1 ) ); + neighbors( predicate_index, count++ ) = + primitive_index; // WARNING see below** + } + } +#else template KOKKOS_FUNCTION void operator()( Predicate const& predicate, int primitive_index ) const @@ -262,6 +346,7 @@ struct NeighborDiscriminatorCallback2D_SecondPass primitive_index; // WARNING see below** } } +#endif }; // NOTE** Taking advantage of the knowledge that one predicate is processed by a @@ -319,7 +404,12 @@ auto makeNeighborList( ExecutionSpace space, Tag, Positions const& positions, using memory_space = typename Positions::memory_space; +#if ARBORX_VERSION >= 10799 + ArborX::BoundingVolumeHierarchy bvh( + space, ArborX::Experimental::attach_indices( positions ) ); +#else ArborX::BVH bvh( space, positions ); +#endif Kokkos::View indices( Kokkos::view_alloc( "indices", Kokkos::WithoutInitializing ), 0 ); @@ -444,14 +534,23 @@ auto make2DNeighborList( ExecutionSpace space, Tag, Positions const& positions, using memory_space = typename Positions::memory_space; +#if ARBORX_VERSION >= 10799 + ArborX::BoundingVolumeHierarchy bvh( + space, ArborX::Experimental::attach_indices( positions ) ); +#else ArborX::BVH bvh( space, positions ); +#endif auto const predicates = Impl::makePredicates( positions, first, last, radius ); auto const n_queries = - ArborX::AccessTraits, - ArborX::PredicatesTag>::size( predicates ); + ArborX::AccessTraits +#if ARBORX_VERSION < 10799 + , + ArborX::PredicatesTag +#endif + >::size( predicates ); Kokkos::View neighbors; Kokkos::View counts( "counts", n_queries );