From 2e36b6b16bc5d8d02110dbf904de87cda621ddec Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Mon, 16 Dec 2024 14:58:12 +0000 Subject: [PATCH] compiler: class HaloSchemeEntry(EnrichedTuple) --- devito/ir/iet/nodes.py | 3 +++ devito/mpi/halo_scheme.py | 22 +++++----------------- devito/passes/iet/mpi.py | 4 +++- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 4aa3df1f2b..57becab5ff 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1437,6 +1437,8 @@ def DummyExpr(*args, init=False): # Nodes required for distributed-memory halo exchange + + class HaloSpot(Node): """ @@ -1498,6 +1500,7 @@ def body(self): def functions(self): return tuple(self.fmapper) + # Utility classes diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 3819caccb4..09cf98ac6a 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -11,7 +11,7 @@ from devito.ir.support import Forward, Scope from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten, - frozendict, is_integer, filter_sorted) + frozendict, is_integer, filter_sorted, EnrichedTuple) from devito.types import Grid __all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch'] @@ -28,34 +28,22 @@ class HaloLabel(Tag): STENCIL = HaloLabel('stencil') -class HaloSchemeEntry(Reconstructable): +class HaloSchemeEntry(EnrichedTuple): __rargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims') - def __init__(self, loc_indices, loc_dirs, halos, dims): + def __init__(self, loc_indices, loc_dirs, halos, dims, getters=None): self.loc_indices = frozendict(loc_indices) self.loc_dirs = frozendict(loc_dirs) self.halos = frozenset(halos) self.dims = frozenset(dims) - def __eq__(self, other): - if not isinstance(other, HaloSchemeEntry): - return False - return (self.loc_indices == other.loc_indices and - self.loc_dirs == other.loc_dirs and - self.halos == other.halos and - self.dims == other.dims) - def __hash__(self): - return hash((tuple(self.loc_indices.items()), - tuple(self.loc_dirs.items()), + return hash((self.loc_indices, + self.loc_dirs, self.halos, self.dims)) - def __repr__(self): - return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, " - f"loc_dirs={self.loc_dirs}, halos={self.halos}, dims={self.dims})") - Halo = namedtuple('Halo', 'dim side') diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 3bc6a61344..930cc108b2 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -356,7 +356,9 @@ def _filter_iter_mapper(iet): def _make_cond_mapper(iet): - "Return a mapper from HaloSpots to the Conditionals that contain them." + """ + Return a mapper from HaloSpots to the Conditionals that contain them. + """ cond_mapper = {} for hs, v in MapHaloSpots().visit(iet).items(): conditionals = {i for i in v if i.is_Conditional and