diff --git a/core/src/Cabana_ParticleGridCommunication.hpp b/core/src/Cabana_ParticleGridCommunication.hpp index d2f2e77ea..5a55df3e5 100644 --- a/core/src/Cabana_ParticleGridCommunication.hpp +++ b/core/src/Cabana_ParticleGridCommunication.hpp @@ -219,7 +219,8 @@ void getMigrateDestinations( const LocalGridType& local_grid, */ template Distributor -gridDistributor( const LocalGridType& local_grid, PositionSliceType& positions ) +createGridDistributor( const LocalGridType& local_grid, + PositionSliceType& positions ) { using device_type = typename PositionSliceType::device_type; @@ -293,7 +294,7 @@ void gridMigrate( const LocalGridType& local_grid, ParticleContainer& particles, return; } - auto distributor = gridDistributor( local_grid, positions ); + auto distributor = createGridDistributor( local_grid, positions ); // Redistribute the particles. migrate( distributor, particles ); @@ -355,7 +356,7 @@ void gridMigrate( const LocalGridType& local_grid, } } - auto distributor = gridDistributor( local_grid, positions ); + auto distributor = createGridDistributor( local_grid, positions ); // Resize as needed. dst_particles.resize( distributor.totalNumImport() ); @@ -622,6 +623,24 @@ struct PeriodicShift } }; +template +class GridHalo +{ + const HaloType _halo; + const ShiftType _shifts; + + public: + GridHalo( const HaloType& halo, const ShiftType& shifts ) + : _halo( halo ) + , _shifts( shifts ) + { + } + + HaloType getHalo() const { return _halo; } + + ShiftType getShifts() const { return _shifts; } +}; + //---------------------------------------------------------------------------// /*! \brief Determine which data should be ghosted on another decomposition, using @@ -644,11 +663,11 @@ struct PeriodicShift \param max_export_guess The allocation size for halo export ranks, IDs, and periodic shifts - \return Pair containing the Halo and PeriodicShift. + \return GridHalo containing Halo and PeriodicShift. */ template -auto gridHalo( +auto createGridHalo( const LocalGridType& local_grid, const PositionSliceType& positions, std::integral_constant, const int min_halo_width, const int max_export_guess = 0, @@ -691,7 +710,10 @@ auto gridHalo( // Create the Shifts. auto periodic_shift = PeriodicShift( shifts ); - return std::make_pair( halo, periodic_shift ); + // Return Halo and PeriodicShifts together. + GridHalo, PeriodicShift> + grid_halo( halo, periodic_shift ); + return grid_halo; } //---------------------------------------------------------------------------// @@ -714,11 +736,11 @@ auto gridHalo( \param max_export_guess The allocation size for halo export ranks, IDs, and periodic shifts. - \return Pair containing the Halo and PeriodicShift. + \return GridHalo containing Halo and PeriodicShift. */ template -auto gridHalo( +auto createGridHalo( const LocalGridType& local_grid, const ParticleContainer& particles, std::integral_constant, const int min_halo_width, const int max_export_guess = 0, @@ -726,9 +748,9 @@ auto gridHalo( 0 ) { auto positions = slice( particles ); - return gridHalo( local_grid, positions, - std::integral_constant(), - min_halo_width, max_export_guess ); + return createGridHalo( local_grid, positions, + std::integral_constant(), + min_halo_width, max_export_guess ); } //---------------------------------------------------------------------------// @@ -749,13 +771,14 @@ auto gridHalo( \param particles The particle AoSoA, containing positions. */ -template -void gridGather( const HaloType& halo, const PeriodicShiftType& shift, - ParticleContainer& particles ) +template +void gridGather( const GridHaloType grid_halo, ParticleContainer& particles ) { + auto halo = grid_halo.getHalo(); + auto shifts = grid_halo.getShifts(); particles.resize( halo.numLocal() + halo.numGhost() ); - gather( halo, particles, shift ); + gather( halo, particles, shifts ); } // TODO: slice version diff --git a/core/unit_test/tstParticleGridCommunication.hpp b/core/unit_test/tstParticleGridCommunication.hpp index bd45cced0..aed6bd0d0 100644 --- a/core/unit_test/tstParticleGridCommunication.hpp +++ b/core/unit_test/tstParticleGridCommunication.hpp @@ -139,7 +139,7 @@ void testMigrate( const int halo_width, const int test_halo_width, data_host.resize( data_dst.size() ); Cabana::deep_copy( data_host, data_dst ); } - // Do the migration with separate slices (need to use gridDistributor + // Do the migration with separate slices (need to use createGridDistributor // directly since slices can't be resized). else if ( test_type == 2 ) { @@ -153,7 +153,8 @@ void testMigrate( const int halo_width, const int test_halo_width, if ( force_comm || comm_count > 0 ) { - auto distributor = Cabana::gridDistributor( *local_grid, pos_src ); + auto distributor = + Cabana::createGridDistributor( *local_grid, pos_src ); Cabana::AoSoA data_dst( "data_dst", distributor.totalNumImport() ); auto pos_dst = Cabana::slice( data_dst ); @@ -226,7 +227,7 @@ void testMigrate( const int halo_width, const int test_halo_width, } //---------------------------------------------------------------------------// -void testHalo( const int halo_width, const int test_type ) +void testGather( const int halo_width, const int test_type ) { // Create the MPI partitions. Cajita::UniformDimPartitioner partitioner; @@ -313,10 +314,10 @@ void testHalo( const int halo_width, const int test_type ) if ( test_type == 0 ) { // Do the gather with the AoSoA. - auto pair = gridHalo( *local_grid, data_src, - std::integral_constant(), - halo_width ); - gridGather( pair.first, pair.second, data_src ); + auto grid_halo = Cabana::createGridHalo( + *local_grid, data_src, + std::integral_constant(), halo_width ); + gridGather( grid_halo, data_src ); data_host.resize( data_src.size() ); Cabana::deep_copy( data_host, data_src ); @@ -438,7 +439,7 @@ TEST( TEST_CATEGORY, periodic_test_migrate_slice ) TEST( TEST_CATEGORY, periodic_test_gather_inplace ) { for ( int i = 1; i < 2; i++ ) - testHalo( i, 0 ); + testGather( i, 0 ); } //---------------------------------------------------------------------------//