diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 378d0be721..72ff6e3c04 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -16,7 +16,7 @@ Forward, WithLock, PrefetchUpdate, detect_io) from devito.symbolics import ListInitializer, CallFromPointer, ccode from devito.tools import (Signer, as_tuple, filter_ordered, filter_sorted, flatten, - ctypes_to_cstr) + ctypes_to_cstr, OrderedSet) from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed, Symbol) from devito.types.object import AbstractObject, LocalObject @@ -1459,7 +1459,7 @@ def __init__(self, body, halo_scheme): def __repr__(self): fstrings = [] for f in self.functions: - loc_indices = set().union(*[self.halo_scheme.fmapper[f].loc_indices.values()]) + loc_indices = OrderedSet(*(self.halo_scheme.fmapper[f].loc_indices.values())) loc_indices_str = str(list(loc_indices)) if loc_indices else "" fstrings.append(f"{f.name}{loc_indices_str}") diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index cc9edbe416..64a47623ff 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -11,7 +11,7 @@ from devito.ir.support import Forward, Scope from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten, - frozendict, is_integer, filter_sorted) + frozendict, is_integer, filter_sorted, OrderedSet) from devito.types import Grid __all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch'] @@ -28,7 +28,9 @@ class HaloLabel(Tag): STENCIL = HaloLabel('stencil') -class HaloSchemeEntry: +class HaloSchemeEntry(Reconstructable): + + __rkwargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims') def __init__(self, loc_indices, loc_dirs, halos, dims): self.loc_indices = loc_indices @@ -36,10 +38,6 @@ def __init__(self, loc_indices, loc_dirs, halos, dims): self.halos = halos self.dims = dims - def __repr__(self): - return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, " - f"loc_dirs={self.loc_dirs}, halos={self.halos}, dims={self.dims})") - def __eq__(self, other): if not isinstance(other, HaloSchemeEntry): return False @@ -54,12 +52,9 @@ def __hash__(self): frozenset(self.halos), frozenset(self.dims))) - def rebuild(self, **kwargs): - loc_indices = kwargs.get('loc_indices', self.loc_indices) - loc_dirs = kwargs.get('loc_dirs', self.loc_dirs) - halos = kwargs.get('halos', self.halos) - dims = kwargs.get('dims', self.dims) - return HaloSchemeEntry(loc_indices, loc_dirs, halos, dims) + def __repr__(self): + return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, " + f"loc_dirs={self.loc_dirs}, halos={self.halos}, dims={self.dims})") Halo = namedtuple('Halo', 'dim side') @@ -129,7 +124,7 @@ def __init__(self, exprs, ispace): def __repr__(self): fstrings = [] for f in self.fmapper: - loc_indices = set().union(*[self._mapper[f].loc_indices.values()]) + loc_indices = OrderedSet(*(self._mapper[f].loc_indices.values())) loc_indices_str = str(list(loc_indices)) if loc_indices else "" fstrings.append(f"{f.name}{loc_indices_str}") diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 323b390c7f..72e351bc84 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -109,7 +109,6 @@ def _hoist_halospots(iet): continue for it in iters: - # If also merge-able we can start hoisting the latter for dep in scopes[it].d_flow.project(f): if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()): break @@ -120,17 +119,18 @@ def _hoist_halospots(iet): # the loc_indices with known values # TODO: Can I get this in a more elegant way? for d in hse.loc_indices: - if hse.loc_indices[d].is_Symbol: - assert d in hse.loc_indices[d]._defines - root_min = hse.loc_indices[d].symbolic_min - new_min = root_min.subs(hse.loc_indices[d].root, - hse.loc_indices[d].root.symbolic_min) + md = hse.loc_indices[d] + if md.is_Symbol: + root = md.root + hse_min = md.symbolic_min + new_min = hse_min.subs(root, root.symbolic_min) raw_loc_indices[d] = new_min else: - assert d.symbolic_min in hse.loc_indices[d].free_symbols - raw_loc_indices[d] = hse.loc_indices[d] + # md is in form of an expression + assert d.symbolic_min in md.free_symbols + raw_loc_indices[d] = md - hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices)) + hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices)) hs1.halo_scheme.fmapper[f] = hse hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)