Skip to content

Commit

Permalink
Address review naming and return GridHalo instead of pair
Browse files Browse the repository at this point in the history
  • Loading branch information
streeve committed Jan 8, 2021
1 parent 0a67b77 commit f0cd924
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
53 changes: 38 additions & 15 deletions core/src/Cabana_ParticleGridCommunication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ void getMigrateDestinations( const LocalGridType& local_grid,
*/
template <class LocalGridType, class PositionSliceType>
Distributor<typename PositionSliceType::device_type>
gridDistributor( const LocalGridType& local_grid, PositionSliceType& positions )
createGridDistributor( const LocalGridType& local_grid,
PositionSliceType& positions )
{
using device_type = typename PositionSliceType::device_type;

Expand Down Expand Up @@ -293,7 +294,7 @@ void gridMigrate( const LocalGridType& local_grid, ParticleContainer& particles,
return;
}

auto distributor = gridDistributor( local_grid, positions );
auto distributor = createGridDistributor( local_grid, positions );

// Redistribute the particles.
migrate( distributor, particles );
Expand Down Expand Up @@ -355,7 +356,7 @@ void gridMigrate( const LocalGridType& local_grid,
}
}

auto distributor = gridDistributor( local_grid, positions );
auto distributor = createGridDistributor( local_grid, positions );

// Resize as needed.
dst_particles.resize( distributor.totalNumImport() );
Expand Down Expand Up @@ -622,6 +623,24 @@ struct PeriodicShift
}
};

template <class HaloType, class ShiftType>
class GridHalo
{
const HaloType _halo;
const ShiftType _shifts;

public:
GridHalo( const HaloType& halo, const ShiftType& shifts )
: _halo( halo )
, _shifts( shifts )
{
}

HaloType getHalo() const { return _halo; }

ShiftType getShifts() const { return _shifts; }
};

//---------------------------------------------------------------------------//
/*!
\brief Determine which data should be ghosted on another decomposition, using
Expand All @@ -644,11 +663,11 @@ struct PeriodicShift
\param max_export_guess The allocation size for halo export ranks, IDs, and
periodic shifts
\return Pair containing the Halo and PeriodicShift.
\return GridHalo containing Halo and PeriodicShift.
*/
template <class LocalGridType, class PositionSliceType,
std::size_t PositionIndex>
auto gridHalo(
auto createGridHalo(
const LocalGridType& local_grid, const PositionSliceType& positions,
std::integral_constant<std::size_t, PositionIndex>,
const int min_halo_width, const int max_export_guess = 0,
Expand Down Expand Up @@ -691,7 +710,10 @@ auto gridHalo(
// Create the Shifts.
auto periodic_shift = PeriodicShift<device_type, PositionIndex>( shifts );

return std::make_pair( halo, periodic_shift );
// Return Halo and PeriodicShifts together.
GridHalo<Halo<device_type>, PeriodicShift<device_type, PositionIndex>>
grid_halo( halo, periodic_shift );
return grid_halo;
}

//---------------------------------------------------------------------------//
Expand All @@ -714,21 +736,21 @@ auto gridHalo(
\param max_export_guess The allocation size for halo export ranks, IDs, and
periodic shifts.
\return Pair containing the Halo and PeriodicShift.
\return GridHalo containing Halo and PeriodicShift.
*/
template <class LocalGridType, class ParticleContainer,
std::size_t PositionIndex>
auto gridHalo(
auto createGridHalo(
const LocalGridType& local_grid, const ParticleContainer& particles,
std::integral_constant<std::size_t, PositionIndex>,
const int min_halo_width, const int max_export_guess = 0,
typename std::enable_if<is_aosoa<ParticleContainer>::value, int>::type* =
0 )
{
auto positions = slice<PositionIndex>( particles );
return gridHalo( local_grid, positions,
std::integral_constant<std::size_t, PositionIndex>(),
min_halo_width, max_export_guess );
return createGridHalo( local_grid, positions,
std::integral_constant<std::size_t, PositionIndex>(),
min_halo_width, max_export_guess );
}

//---------------------------------------------------------------------------//
Expand All @@ -749,13 +771,14 @@ auto gridHalo(
\param particles The particle AoSoA, containing positions.
*/
template <class HaloType, class PeriodicShiftType, class ParticleContainer>
void gridGather( const HaloType& halo, const PeriodicShiftType& shift,
ParticleContainer& particles )
template <class GridHaloType, class ParticleContainer>
void gridGather( const GridHaloType grid_halo, ParticleContainer& particles )
{
auto halo = grid_halo.getHalo();
auto shifts = grid_halo.getShifts();
particles.resize( halo.numLocal() + halo.numGhost() );

gather( halo, particles, shift );
gather( halo, particles, shifts );
}

// TODO: slice version
Expand Down
17 changes: 9 additions & 8 deletions core/unit_test/tstParticleGridCommunication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void testMigrate( const int halo_width, const int test_halo_width,
data_host.resize( data_dst.size() );
Cabana::deep_copy( data_host, data_dst );
}
// Do the migration with separate slices (need to use gridDistributor
// Do the migration with separate slices (need to use createGridDistributor
// directly since slices can't be resized).
else if ( test_type == 2 )
{
Expand All @@ -153,7 +153,8 @@ void testMigrate( const int halo_width, const int test_halo_width,

if ( force_comm || comm_count > 0 )
{
auto distributor = Cabana::gridDistributor( *local_grid, pos_src );
auto distributor =
Cabana::createGridDistributor( *local_grid, pos_src );
Cabana::AoSoA<DataTypes, TEST_MEMSPACE> data_dst(
"data_dst", distributor.totalNumImport() );
auto pos_dst = Cabana::slice<pos_index>( data_dst );
Expand Down Expand Up @@ -226,7 +227,7 @@ void testMigrate( const int halo_width, const int test_halo_width,
}

//---------------------------------------------------------------------------//
void testHalo( const int halo_width, const int test_type )
void testGather( const int halo_width, const int test_type )
{
// Create the MPI partitions.
Cajita::UniformDimPartitioner partitioner;
Expand Down Expand Up @@ -313,10 +314,10 @@ void testHalo( const int halo_width, const int test_type )
if ( test_type == 0 )
{
// Do the gather with the AoSoA.
auto pair = gridHalo( *local_grid, data_src,
std::integral_constant<std::size_t, pos_index>(),
halo_width );
gridGather( pair.first, pair.second, data_src );
auto grid_halo = Cabana::createGridHalo(
*local_grid, data_src,
std::integral_constant<std::size_t, pos_index>(), halo_width );
gridGather( grid_halo, data_src );

data_host.resize( data_src.size() );
Cabana::deep_copy( data_host, data_src );
Expand Down Expand Up @@ -438,7 +439,7 @@ TEST( TEST_CATEGORY, periodic_test_migrate_slice )
TEST( TEST_CATEGORY, periodic_test_gather_inplace )
{
for ( int i = 1; i < 2; i++ )
testHalo( i, 0 );
testGather( i, 0 );
}
//---------------------------------------------------------------------------//

Expand Down

0 comments on commit f0cd924

Please sign in to comment.