Skip to content

Commit

Permalink
compiler: Reduce some code after reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Dec 19, 2024
1 parent 2ed05aa commit a54684a
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
if d in fixed:
continue

name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(name, nb)
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(name, nb)
rpeer, lpeer = self._make_peers(d, distributor, nb)

if (d, LEFT) in hse.halos:
# Sending to left, receiving from right
Expand Down Expand Up @@ -491,6 +488,14 @@ def _make_basic_mapper(self, f, fixed):

return mapper

def _make_peers(self, d, distributor, nb):
rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(rname, nb)
lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(lname, nb)

return rpeer, lpeer

def _call_haloupdate(self, name, f, hse, *args):
comm = f.grid.distributor._obj_comm
nb = f.grid.distributor._obj_neighborhood
Expand Down Expand Up @@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
if d in fixed:
continue

name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(name, nb)
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(name, nb)
rpeer, lpeer = self._make_peers(d, distributor, nb)

if (d, LEFT) in hse.halos:
# Sending to left, receiving from right
Expand Down Expand Up @@ -1297,6 +1299,7 @@ def _as_number(self, v, args):
return int(subs_op_args(v, args))

def _allocate_buffers(self, f, shape, entry):
# Allocate the send/recv buffers
entry.sizes = (c_int*len(shape))(*shape)
size = reduce(mul, shape)*dtype_len(self.target.dtype)
ctype = dtype_to_ctype(f.dtype)
Expand Down Expand Up @@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None):
if d in fixed:
continue

if (d, LEFT) in self.halos:
entry = self.value[i]
i = i + 1
# Sending to left, receiving from right
shape = mapper[(d, LEFT, OWNED)]
# Allocate the send/recv buffers
self._allocate_buffers(f, shape, entry)

if (d, RIGHT) in self.halos:
entry = self.value[i]
i = i + 1
# Sending to right, receiving from left
shape = mapper[(d, RIGHT, OWNED)]
# Allocate the send/recv buffers
self._allocate_buffers(f, shape, entry)
for side in (LEFT, RIGHT):
if (d, side) in self.halos:
entry = self.value[i]
i += 1
shape = mapper[(d, side, OWNED)]
self._allocate_buffers(f, shape, entry)

return {self.name: self.value}

Expand Down

0 comments on commit a54684a

Please sign in to comment.