Skip to content

Commit

Permalink
compiler: Reduce some code after reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Dec 19, 2024
1 parent 2ed05aa commit a94b1d5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 27 deletions.
4 changes: 2 additions & 2 deletions benchmarks/user/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ and run with `mpirun -n number_of_processes python benchmark.py ...`

Devito supports multiple MPI schemes for halo exchange.

* Devito's most prevalent MPI modes are three: `basic`, `diag2` and `full`.
and are respectively activated e.g., via `DEVITO_MPI=basic`.
* Devito's most prevalent MPI modes are three: `basic2`, `diag2` and `full`.
and are respectively activated e.g., via `DEVITO_MPI=basic2`.
These modes may perform better under different factors such as arithmetic intensity,
or number of fields used in the computation.

Expand Down
42 changes: 18 additions & 24 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
if d in fixed:
continue

name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(name, nb)
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(name, nb)
rpeer, lpeer = self._make_peers(d, distributor, nb)

if (d, LEFT) in hse.halos:
# Sending to left, receiving from right
Expand Down Expand Up @@ -491,6 +488,14 @@ def _make_basic_mapper(self, f, fixed):

return mapper

def _make_peers(self, d, distributor, nb):
rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(rname, nb)
lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(lname, nb)

return rpeer, lpeer

def _call_haloupdate(self, name, f, hse, *args):
comm = f.grid.distributor._obj_comm
nb = f.grid.distributor._obj_neighborhood
Expand Down Expand Up @@ -537,7 +542,7 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits):
class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):

"""
A BasicHaloExchangeBuilder making use of pre-allocated buffers for
A BasicHaloExchangeBuilder using pre-allocated buffers for
message size.
Generates:
Expand Down Expand Up @@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
if d in fixed:
continue

name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(name, nb)
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(name, nb)
rpeer, lpeer = self._make_peers(d, distributor, nb)

if (d, LEFT) in hse.halos:
# Sending to left, receiving from right
Expand Down Expand Up @@ -1297,6 +1299,7 @@ def _as_number(self, v, args):
return int(subs_op_args(v, args))

def _allocate_buffers(self, f, shape, entry):
# Allocate the send/recv buffers
entry.sizes = (c_int*len(shape))(*shape)
size = reduce(mul, shape)*dtype_len(self.target.dtype)
ctype = dtype_to_ctype(f.dtype)
Expand Down Expand Up @@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None):
if d in fixed:
continue

if (d, LEFT) in self.halos:
entry = self.value[i]
i = i + 1
# Sending to left, receiving from right
shape = mapper[(d, LEFT, OWNED)]
# Allocate the send/recv buffers
self._allocate_buffers(f, shape, entry)

if (d, RIGHT) in self.halos:
entry = self.value[i]
i = i + 1
# Sending to right, receiving from left
shape = mapper[(d, RIGHT, OWNED)]
# Allocate the send/recv buffers
self._allocate_buffers(f, shape, entry)
for side in (LEFT, RIGHT):
if (d, side) in self.halos:
entry = self.value[i]
i += 1
shape = mapper[(d, side, OWNED)]
self._allocate_buffers(f, shape, entry)

return {self.name: self.value}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2269,7 +2269,7 @@ def test_haloupdate_issue_1613(self, mode):
assert dims[0].is_Modulo
assert dims[0].origin is t

@pytest.mark.parallel(mode=[(4, 'basic'), (4, 'diag2'), (4, 'overlap2')])
@pytest.mark.parallel(mode=[(4, 'basic2'), (4, 'diag2'), (4, 'overlap2')])
def test_cire(self, mode):
"""
Check correctness when the DSE extracts aliases and places them
Expand Down

0 comments on commit a94b1d5

Please sign in to comment.