diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index b176dfbd8c..b07934e76b 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -429,6 +429,14 @@ def _call_sendrecv(self, name, *args, **kwargs): args = list(args[0].handles) + flatten(args[1:]) return Call(name, args) + def _make_peers(self, d, distributor, nb): + rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions) + rpeer = FieldFromPointer(rname, nb) + lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions) + lpeer = FieldFromPointer(lname, nb) + + return rpeer, lpeer + def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): distributor = f.grid.distributor nb = distributor._obj_neighborhood @@ -443,10 +451,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): if d in fixed: continue - name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) - rpeer = FieldFromPointer(name, nb) - name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) - lpeer = FieldFromPointer(name, nb) + rpeer, lpeer = self._make_peers(d, distributor, nb) if (d, LEFT) in hse.halos: # Sending to left, receiving from right @@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): if d in fixed: continue - name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) - rpeer = FieldFromPointer(name, nb) - name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) - lpeer = FieldFromPointer(name, nb) + rpeer, lpeer = self._make_peers(d, distributor, nb) if (d, LEFT) in hse.halos: # Sending to left, receiving from right @@ -1297,6 +1299,7 @@ def _as_number(self, v, args): return int(subs_op_args(v, args)) def _allocate_buffers(self, f, shape, entry): + # Allocate the send/recv buffers entry.sizes = (c_int*len(shape))(*shape) size = reduce(mul, shape)*dtype_len(self.target.dtype) ctype = dtype_to_ctype(f.dtype) @@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None): if d in fixed: continue - if (d, LEFT) in self.halos: - entry = self.value[i] - i = i + 1 - # Sending to left, receiving from right - shape = mapper[(d, LEFT, OWNED)] - # Allocate the send/recv buffers - self._allocate_buffers(f, shape, entry) - - if (d, RIGHT) in self.halos: - entry = self.value[i] - i = i + 1 - # Sending to right, receiving from left - shape = mapper[(d, RIGHT, OWNED)] - # Allocate the send/recv buffers - self._allocate_buffers(f, shape, entry) + for side in (LEFT, RIGHT): + if (d, side) in self.halos: + entry = self.value[i] + i += 1 + shape = mapper[(d, side, OWNED)] + self._allocate_buffers(f, shape, entry) return {self.name: self.value}