From a94b1d56a3da9ceffb0112d775dffa196c0f9a52 Mon Sep 17 00:00:00 2001 From: George BIsbas Date: Thu, 19 Dec 2024 14:50:14 +0200 Subject: [PATCH] compiler: Reduce some code after reviews --- benchmarks/user/README.md | 4 ++-- devito/mpi/routines.py | 42 +++++++++++++++++---------------------- tests/test_mpi.py | 2 +- 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/benchmarks/user/README.md b/benchmarks/user/README.md index 21fd2c6275..c776970dcc 100644 --- a/benchmarks/user/README.md +++ b/benchmarks/user/README.md @@ -95,8 +95,8 @@ and run with `mpirun -n number_of_processes python benchmark.py ...` Devito supports multiple MPI schemes for halo exchange. -* Devito's most prevalent MPI modes are three: `basic`, `diag2` and `full`. -and are respectively activated e.g., via `DEVITO_MPI=basic`. +* Devito's most prevalent MPI modes are three: `basic2`, `diag2` and `full`. +and are respectively activated e.g., via `DEVITO_MPI=basic2`. These modes may perform better under different factors such as arithmetic intensity, or number of fields used in the computation. diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index b176dfbd8c..c47d3923bf 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -443,10 +443,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): if d in fixed: continue - name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) - rpeer = FieldFromPointer(name, nb) - name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) - lpeer = FieldFromPointer(name, nb) + rpeer, lpeer = self._make_peers(d, distributor, nb) if (d, LEFT) in hse.halos: # Sending to left, receiving from right @@ -491,6 +488,14 @@ def _make_basic_mapper(self, f, fixed): return mapper + def _make_peers(self, d, distributor, nb): + rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions) + rpeer = FieldFromPointer(rname, nb) + lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions) + lpeer = FieldFromPointer(lname, nb) + + return rpeer, lpeer + def _call_haloupdate(self, name, f, hse, *args): comm = f.grid.distributor._obj_comm nb = f.grid.distributor._obj_neighborhood @@ -537,7 +542,7 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits): class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder): """ - A BasicHaloExchangeBuilder making use of pre-allocated buffers for + A BasicHaloExchangeBuilder using pre-allocated buffers for message size. Generates: @@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): if d in fixed: continue - name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) - rpeer = FieldFromPointer(name, nb) - name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) - lpeer = FieldFromPointer(name, nb) + rpeer, lpeer = self._make_peers(d, distributor, nb) if (d, LEFT) in hse.halos: # Sending to left, receiving from right @@ -1297,6 +1299,7 @@ def _as_number(self, v, args): return int(subs_op_args(v, args)) def _allocate_buffers(self, f, shape, entry): + # Allocate the send/recv buffers entry.sizes = (c_int*len(shape))(*shape) size = reduce(mul, shape)*dtype_len(self.target.dtype) ctype = dtype_to_ctype(f.dtype) @@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None): if d in fixed: continue - if (d, LEFT) in self.halos: - entry = self.value[i] - i = i + 1 - # Sending to left, receiving from right - shape = mapper[(d, LEFT, OWNED)] - # Allocate the send/recv buffers - self._allocate_buffers(f, shape, entry) - - if (d, RIGHT) in self.halos: - entry = self.value[i] - i = i + 1 - # Sending to right, receiving from left - shape = mapper[(d, RIGHT, OWNED)] - # Allocate the send/recv buffers - self._allocate_buffers(f, shape, entry) + for side in (LEFT, RIGHT): + if (d, side) in self.halos: + entry = self.value[i] + i += 1 + shape = mapper[(d, side, OWNED)] + self._allocate_buffers(f, shape, entry) return {self.name: self.value} diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 2de796f0b7..da1967ee0f 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2269,7 +2269,7 @@ def test_haloupdate_issue_1613(self, mode): assert dims[0].is_Modulo assert dims[0].origin is t - @pytest.mark.parallel(mode=[(4, 'basic'), (4, 'diag2'), (4, 'overlap2')]) + @pytest.mark.parallel(mode=[(4, 'basic2'), (4, 'diag2'), (4, 'overlap2')]) def test_cire(self, mode): """ Check correctness when the DSE extracts aliases and places them