Skip to content

Commit

Permalink
Merge branch 'develop' into feature/pack_vector_fields
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax authored Aug 21, 2024
2 parents 43954ef + c43d321 commit 2d3dd0d
Showing 1 changed file with 57 additions and 22 deletions.
79 changes: 57 additions & 22 deletions src/tests/parallel/test_haloexchange.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct validate {
};

struct Fixture {
Fixture(bool on_device = false): on_device_(on_device) {
Fixture(bool _on_device = false): on_device(_on_device) {
int nnodes_c[] = {5, 6, 7};
nb_nodes = vec(nnodes_c);
N = nb_nodes[mpi::comm().rank()];
Expand Down Expand Up @@ -147,7 +147,7 @@ struct Fixture {
std::vector<POD> gidx;

int N;
bool on_device_;
bool on_device;
};

//-----------------------------------------------------------------------------
Expand All @@ -159,11 +159,15 @@ void test_rank0_arrview(Fixture& f) {
arrv(j) = (size_t(f.part[j]) != mpi::comm().rank() ? 0 : f.gidx[j]);
}

arr.syncHostDevice();
if (f.on_device) {
arr.updateDevice();
}

f.halo_exchange.execute<POD, 1>(arr, f.on_device_);
f.halo_exchange.execute<POD, 1>(arr, f.on_device);

arr.syncHostDevice();
if( f.on_device ) {
arr.updateHost();
}

switch (mpi::comm().rank()) {
case 0: {
Expand Down Expand Up @@ -192,11 +196,15 @@ void test_rank1(Fixture& f) {
arrv(j, 1) = (size_t(f.part[j]) != mpi::comm().rank() ? 0 : f.gidx[j] * 100);
}

arr.syncHostDevice();
if (f.on_device) {
arr.updateDevice();
}

f.halo_exchange.execute<POD, 2>(arr, f.on_device_);
f.halo_exchange.execute<POD, 2>(arr, f.on_device);

arr.syncHostDevice();
if (f.on_device) {
arr.updateHost();
}

switch (mpi::comm().rank()) {
case 0: {
Expand Down Expand Up @@ -244,11 +252,15 @@ void test_rank1_strided_v1(Fixture& f) {
#endif
}));

arr->syncHostDevice();
if (f.on_device) {
arr->updateDevice();
}

f.halo_exchange.execute<POD, 2>(*arr, f.on_device_);
f.halo_exchange.execute<POD, 2>(*arr, f.on_device);

arr->syncHostDevice();
if (f.on_device) {
arr->updateHost();
}

switch (mpi::comm().rank()) {
case 0: {
Expand Down Expand Up @@ -295,8 +307,6 @@ void test_rank1_strided_v2(Fixture& f) {
#endif
}));

arr->syncHostDevice();

f.halo_exchange.execute<POD, 2>(*arr, false);

switch (mpi::comm().rank()) {
Expand Down Expand Up @@ -329,11 +339,15 @@ void test_rank2(Fixture& f) {
}
}

arr.syncHostDevice();
if (f.on_device) {
arr.updateDevice();
}

f.halo_exchange.execute<POD, 3>(arr, f.on_device_);
f.halo_exchange.execute<POD, 3>(arr, f.on_device);

arr.syncHostDevice();
if (f.on_device) {
arr.updateHost();
}

switch (mpi::comm().rank()) {
case 0: {
Expand Down Expand Up @@ -444,7 +458,16 @@ void test_rank2_l2_v2(Fixture& f) {
#endif
}));

f.halo_exchange.execute<POD, 3>(*arr, f.on_device_);

if (f.on_device) {
arr.updateDevice();
}

f.halo_exchange.execute<POD, 3>(*arr, f.on_device);

if (f.on_device) {
arr.updateHost();
}

switch (mpi::comm().rank()) {
case 0: {
Expand Down Expand Up @@ -504,7 +527,15 @@ void test_rank2_v2(Fixture& f) {
#endif
}));

f.halo_exchange.execute<POD, 3>(*arr, f.on_device_);
if (f.on_device) {
arr.updateDevice();
}

f.halo_exchange.execute<POD, 3>(*arr, f.on_device);

if (f.on_device) {
arr.updateHost();
}

switch (mpi::comm().rank()) {
case 0: {
Expand Down Expand Up @@ -545,11 +576,15 @@ void test_rank0_wrap(Fixture& f) {
std::unique_ptr<array::Array> arr(array::Array::wrap<POD>(f.gidx.data(), array::make_shape(f.N)));
array::ArrayView<POD, 1> arrv = array::make_view<POD, 1>(*arr);

arr->syncHostDevice();
if (f.on_device) {
arr->updateDevice();
}

f.halo_exchange.execute<POD, 1>(*arr, f.on_device_);
f.halo_exchange.execute<POD, 1>(*arr, f.on_device);

arr->syncHostDevice();
if (f.on_device) {
arr->updateHost();
}

switch (mpi::comm().rank()) {
case 0: {
Expand Down Expand Up @@ -700,7 +735,7 @@ CASE("test_haloexchange") {
SECTION("test_rank1_cinterface") { test_rank1_cinterface(f); }
}

#if ATLAS_GRIDTOOLS_STORAGE_BACKEND_CUDA
#if ATLAS_HAVE_CUDA
CASE("test_haloexchange on device") {
bool on_device = true;
Fixture f(on_device);
Expand Down

0 comments on commit 2d3dd0d

Please sign in to comment.