Skip to content

Commit

Permalink
compiler: Revert __repr__, add more minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Dec 16, 2024
1 parent 8bfceeb commit 018ba54
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 35 deletions.
9 changes: 4 additions & 5 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,7 +1437,6 @@ def DummyExpr(*args, init=False):


# Nodes required for distributed-memory halo exchange

class HaloSpot(Node):

"""
Expand All @@ -1463,6 +1462,10 @@ def __init__(self, body, halo_scheme):

self._halo_scheme = halo_scheme

def __repr__(self):
functions = "(%s)" % ",".join(i.name for i in self.functions)
return "<%s%s>" % (self.__class__.__name__, functions)

@property
def halo_scheme(self):
return self._halo_scheme
Expand Down Expand Up @@ -1495,10 +1498,6 @@ def body(self):
def functions(self):
return tuple(self.fmapper)

def __repr__(self):
funcs = self.halo_scheme.__reprfuncs__()
return "<%s(%s)>" % (self.__class__.__name__, funcs)

# Utility classes


Expand Down
18 changes: 4 additions & 14 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from devito.ir.support import Forward, Scope
from devito.symbolics.manipulation import _uxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer, filter_sorted, OrderedSet)
frozendict, is_integer, filter_sorted)
from devito.types import Grid

__all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch']
Expand Down Expand Up @@ -62,7 +62,7 @@ def __repr__(self):
OMapper = namedtuple('OMapper', 'core owned')


class HaloScheme():
class HaloScheme:

"""
A HaloScheme describes a set of halo exchanges through a mapper:
Expand Down Expand Up @@ -120,17 +120,9 @@ def __init__(self, exprs, ispace):
self._honored[i.root] = frozenset([(ltk, rtk)])
self._honored = frozendict(self._honored)

def __reprfuncs__(self):
fstrings = []
for f in self.fmapper.keys():
loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""
fstrings.append("%s%s" % (f.name, loc_indices_str))

return ",".join(fstrings)

def __repr__(self):
return "<%s(%s)>" % (self.__class__.__name__, self.__reprfuncs__())
fnames = ",".join(i.name for i in set(self._mapper))
return "HaloScheme<%s>" % fnames

def __eq__(self, other):
return (isinstance(other, HaloScheme) and
Expand Down Expand Up @@ -545,8 +537,6 @@ def classify(exprs, ispace):

loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
ispace.directions)
halos = frozenset(halos)
dims = frozenset(dims)

mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)

Expand Down
20 changes: 8 additions & 12 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,18 @@ def _hoist_invariant(iet):
if not any(r(dep, hs1, v.loc_indices) for r in rules):
break
else:
# hs1 can be hoisted out of `it`, but we need to infer valid
# `hs1`` can be hoisted out of `it`, but we need to infer valid
# loc_indices
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
loc_indices = {}

for d, v in hse.loc_indices.items():
if v in it.uindices:
v_sub = it.start
raw_loc_indices[d] = v.symbolic_min.subs(it.dim, v_sub)
loc_indices[d] = v.symbolic_min.subs(it.dim, it.start)
else:
raw_loc_indices[d] = v
loc_indices[d] = v

hse = hse._rebuild(loc_indices=raw_loc_indices)
hse = hse._rebuild(loc_indices=loc_indices)
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
Expand Down Expand Up @@ -357,14 +356,11 @@ def _filter_iter_mapper(iet):


def _make_cond_mapper(iet):

"Return a mapper from HaloSpots to the Conditionals that contain them."
cond_mapper = {}
for hs, v in MapHaloSpots().visit(iet).items():
conditionals = set()
for i in v:
if i.is_Conditional and not isinstance(i.condition, GuardFactorEq):
conditionals.add(i)

conditionals = {i for i in v if i.is_Conditional and
not isinstance(i.condition, GuardFactorEq)}
cond_mapper[hs] = conditionals

return cond_mapper
Expand Down
9 changes: 5 additions & 4 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ def test_unmerge_haloupdate_if_no_locindices(self, mode):
assert np.allclose(g.data_ro_domain[0, 5:], [16., 16., 14., 13., 6.], rtol=R)

@pytest.mark.parallel(mode=1)
def test_merge_haloupdate_if_diff_locindices_v0(self, mode):
def test_merge_haloupdate_if_diff_locindices(self, mode):
grid = Grid(shape=(101, 101))
x, y = grid.dimensions
t = grid.stepping_dim
Expand All @@ -1320,11 +1320,12 @@ def test_merge_haloupdate_if_diff_locindices_v0(self, mode):
op.cfunction

@pytest.mark.parallel(mode=2)
def test_merge_haloupdate_if_diff_locindices_v1(self, mode):
def test_merge_and_hoist_haloupdate_if_diff_locindices(self, mode):
"""
This test is a revisited, more complex version of
`test_merge_haloupdate_if_diff_locindices_v0`. And in addition to
checking the generated code, it also checks the numerical output.
`test_merge_haloupdate_if_diff_locindices`, also checking hoisting.
And in addition to checking the generated code,
it also checks the numerical output.
In the Operator there are three Eqs:
Expand Down

0 comments on commit 018ba54

Please sign in to comment.