Skip to content

Commit

Permalink
compiler: Rename and minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 6, 2024
1 parent c620588 commit 9f3c009
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 59 deletions.
121 changes: 65 additions & 56 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
@iet_pass
def optimize_halospots(iet, **kwargs):
"""
Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, merged and moved
around in order to improve the halo exchange performance.
Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, hoisted,
merged and moved around in order to improve the halo exchange performance.
"""
iet = _drop_halospots(iet)
iet = _hoist_halospots(iet)
iet = _merge_halospots(iet)
iet = _hoist_invariant(iet)
iet = _hoist_and_merge(iet)
iet = _drop_if_unwritten(iet, **kwargs)
iet = _mark_overlappable(iet)

Expand Down Expand Up @@ -55,10 +55,12 @@ def _drop_halospots(iet):
return iet


def _hoist_halospots(iet):
def _hoist_invariant(iet):
"""
Hoist HaloSpots from inner to outer Iterations where all data dependencies
would be honored.
would be honored. This pass is particularly useful to avoid redundant
halo exchanges when the same data is redundantly exchanged within the
same Iteration tree level.
Example:
haloupd v[t0]
Expand All @@ -80,14 +82,20 @@ def _hoist_halospots(iet):
hsmapper = {}
imapper = defaultdict(list)

# Look for parent Iterations of children HaloSpots
for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items():
for i, hs0 in enumerate(halo_spots):
nodes = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)

# Nothing to do if the HaloSpot is void
if hs0.halo_scheme.is_void:
continue
# Drop void `halo_scheme`s from the analysis
iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void]
for k, v in nodes.items()}

# Drop pairs that have keys that are None
iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}

for it, halo_spots in iter_mapper.items():
if len(halo_spots) <= 1:
continue

for i, hs0 in enumerate(halo_spots):
for hs1 in halo_spots[i+1:]:
# If there are Conditionals involved, both `hs0` and `hs1` must be
# within the same Conditional, otherwise we would break the control
Expand All @@ -100,74 +108,73 @@ def _hoist_halospots(iet):

# Loop over the functions in the HaloSpots
for f, v in hs1.fmapper.items():
# If no time accesses, skip
if not hs1.halo_scheme.fmapper[f].loc_indices:
continue

# If the function is not in both HaloSpots, skip
if f not in hs0.functions:
continue

for it in iters:
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
else:
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Entering here means we can lift, and we need to update
# the loc_indices with known values
# TODO: Can I get this in a more elegant way?
for d in hse.loc_indices:
md = hse.loc_indices[d]
if md.is_Symbol:
root = md.root
hse_min = md.symbolic_min
new_min = hse_min.subs(root, root.symbolic_min)
raw_loc_indices[d] = new_min
else:
# md is in form of an expression
assert d.symbolic_min in md.free_symbols
raw_loc_indices[d] = md

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[it].append(hs1.halo_scheme.project(f))
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in hoist_rules()):
break
else:
# hs1 is lifted out of `it`
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Get the new indexing for the lifted HaloSpot
for d in hse.loc_indices:
md = hse.loc_indices[d]
if md in it.uindices:
new_min = md.symbolic_min.subs(it.dim,
it.dim.symbolic_min)
raw_loc_indices[d] = new_min
else:
raw_loc_indices[d] = md

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[it].append(hs1.halo_scheme.project(f))

mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
for i, hss in imapper.items()}
mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
for i, hs in hsmapper.items()})

iet = Transformer(mapper, nested=True).visit(iet)

return iet


def _merge_halospots(iet):
def _hoist_and_merge(iet):
"""
Merge HaloSpots on the same Iteration tree level where all data dependencies
would be honored.
Hoist and merge HaloSpots on the same Iteration tree level where all data dependencies
would be honored. Helps to avoid redundant halo exchanges when the same data is
redundantly exchanged within the same Iteration tree level as well as to initiate
multiple halo exchanges at once.
Example:
for time for time
haloupd v[t0] haloupd v[t0], h
W v[t1]- R v[t0] W v[t1]- R v[t0]
haloupd v[t0], h
W g[t1]- R v[t0], h W g[t1]- R v[t0], h
"""

# Analysis
cond_mapper = _make_cond_mapper(iet)

mapper = {}
for iter, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
if iter is None or len(halo_spots) <= 1:
for it, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
if it is None or len(halo_spots) <= 1:
continue

scope = Scope([e.expr for e in FindNodes(Expression).visit(iter)])
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])

hs0 = halo_spots[0]
mapper[hs0] = hs0.halo_scheme

for hs1 in halo_spots[1:]:
mapper[hs1] = hs1.halo_scheme

# If there are Conditionals involved, both `hs0` and `hs1` must be
# within the same Conditional, otherwise we would break the control
# flow semantics
Expand All @@ -176,12 +183,14 @@ def _merge_halospots(iet):

for f, v in hs1.fmapper.items():
for dep in scope.d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
if not any(r(dep, hs1, v.loc_indices) for r in hoist_rules()):
break
else:
# hs1 is lifted to hs0 level and merged with hs0
hs = hs1.halo_scheme.project(f)
mapper[hs0] = HaloScheme.union([mapper[hs0], hs])
mapper[hs1] = mapper[hs1].drop(f)

mapper[hs0] = HaloScheme.union([mapper.get(hs0, hs0.halo_scheme), hs])
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)

# Post-process analysis
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
Expand Down Expand Up @@ -362,7 +371,7 @@ def _make_cond_mapper(iet):
for hs, v in cond_mapper.items()}


def merge_rules():
def hoist_rules():
# Merge rules -- if the retval is True, then it means the input `dep` is not
# a stopper to halo merging

Expand Down
4 changes: 1 addition & 3 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,9 +1017,7 @@ def test_issue_2448(self, mode):
shape = (2,)
so = 2

x = SpaceDimension(name='x', spacing=Constant(name='h_x',
value=extent[0]/(shape[0]-1)))
grid = Grid(extent=extent, shape=shape, dimensions=(x,))
grid = Grid(extent=extent, shape=shape)

# Time related
tn = 30
Expand Down

0 comments on commit 9f3c009

Please sign in to comment.