compiler: Rework terminology and minor cleanup
georgebisbas committed Nov 8, 2024
1 parent 3531563 commit 5951eaf
Showing 4 changed files with 178 additions and 96 deletions.
2 changes: 1 addition & 1 deletion devito/ir/iet/nodes.py
@@ -1466,7 +1466,7 @@ def __repr__(self):

functions = ",".join(fstrings)

return f"<{self.__class__.__name__}({functions})>"
return "<%s(%s)>" % (self.__class__.__name__, functions)

@property
def halo_scheme(self):
6 changes: 3 additions & 3 deletions devito/mpi/halo_scheme.py
@@ -131,7 +131,7 @@ def __repr__(self):

functions = ",".join(fstrings)

return f"<{self.__class__.__name__}({functions})>"
return "<%s(%s)>" % (self.__class__.__name__, functions)

def __eq__(self, other):
return (isinstance(other, HaloScheme) and
@@ -677,8 +677,8 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
# Nope, let's try with the next Indexed, if any
continue

hse = hse0.rebuild(loc_indices=frozendict(loc_indices),
loc_dirs=frozendict(loc_dirs))
hse = hse0._rebuild(loc_indices=frozendict(loc_indices),
loc_dirs=frozendict(loc_dirs))

else:
continue
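The `hse0._rebuild(...)` call in the hunk above refreshes a HaloSchemeEntry by rebuilding it with new frozen mappings rather than mutating it in place. A minimal sketch of that copy-with-overrides pattern, using a namedtuple and MappingProxyType as hypothetical stand-ins for Devito's HaloSchemeEntry and frozendict:

from collections import namedtuple
from types import MappingProxyType as frozenmap  # stand-in for Devito's frozendict

# Hypothetical immutable record standing in for a HaloSchemeEntry
Entry = namedtuple('Entry', 'loc_indices loc_dirs')

def rebuild(entry, **overrides):
    # Copy the record, replacing only the supplied fields
    return entry._replace(**overrides)

hse0 = Entry(loc_indices=frozenmap({'t': 't0'}), loc_dirs=frozenmap({'t': 'Forward'}))
hse = rebuild(hse0, loc_indices=frozenmap({'t': 't1'}))
assert hse.loc_dirs is hse0.loc_dirs  # untouched fields are carried over unchanged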
189 changes: 106 additions & 83 deletions devito/passes/iet/mpi.py
@@ -1,6 +1,7 @@
from collections import defaultdict

from sympy import S
from itertools import combinations

from devito.ir.iet import (Call, Expression, HaloSpot, Iteration, FindNodes,
MapNodes, MapHaloSpots, Transformer,
@@ -19,19 +20,19 @@
@iet_pass
def optimize_halospots(iet, **kwargs):
"""
Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, merged and moved
around in order to improve the halo exchange performance.
Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, hoisted,
merged and moved around in order to improve the halo exchange performance.
"""
iet = _drop_halospots(iet)
iet = _hoist_halospots(iet)
iet = _drop_reduction_halospots(iet)
iet = _hoist_invariant(iet)
iet = _merge_halospots(iet)
iet = _drop_if_unwritten(iet, **kwargs)
iet = _mark_overlappable(iet)

return iet, {}


def _drop_halospots(iet):
def _drop_reduction_halospots(iet):
"""
Remove HaloSpots that:
@@ -48,17 +49,19 @@ def _drop_halospots(iet):
mapper[hs].add(f)

# Transform the IET introducing the "reduced" HaloSpots
subs = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs]))
for hs in FindNodes(HaloSpot).visit(iet)}
iet = Transformer(subs, nested=True).visit(iet)
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs]))
for hs in FindNodes(HaloSpot).visit(iet)}
iet = Transformer(mapper, nested=True).visit(iet)

return iet


def _hoist_halospots(iet):
def _hoist_invariant(iet):
"""
Hoist HaloSpots from inner to outer Iterations where all data dependencies
would be honored.
would be honored. This pass is particularly useful to avoid redundant
halo exchanges when the same data would otherwise be exchanged multiple
times within the same Iteration tree level.
Example:
haloupd v[t0]
@@ -80,108 +83,123 @@ def _hoist_halospots(iet):
hsmapper = {}
imapper = defaultdict(list)

# Look for parent Iterations of children HaloSpots
for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items():
for i, hs0 in enumerate(halo_spots):
iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)

# Drop void `halo_scheme`s from the analysis
iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void]
for k, v in iter_mapper.items()}

# Drop entries whose key is None
iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}

# Drop iter_mapper pairs where len(halo_spots) <= 1
iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1}

for it, halo_spots in iter_mapper.items():

# Nothing to do if the HaloSpot is void
if hs0.halo_scheme.is_void:
for hs0, hs1 in combinations(halo_spots, r=2):

if ensure_control_flow(hs0, hs1, cond_mapper):
continue

for hs1 in halo_spots[i+1:]:
# If there are Conditionals involved, both `hs0` and `hs1` must be
# within the same Conditional, otherwise we would break the control
if cond_mapper.get(hs0) != cond_mapper.get(hs1):
continue
# If there are overlapping time accesses, skip
hs0_mdims = hs0.halo_scheme.loc_values
hs1_mdims = hs1.halo_scheme.loc_values
if hs0_mdims.intersection(hs1_mdims):
continue

# If there are overlapping time accesses, skip
if hs0.halo_scheme.loc_values.intersection(hs1.halo_scheme.loc_values):
# Loop over the functions in the HaloSpots
for f, v in hs1.fmapper.items():

# If the function is not in both HaloSpots, skip
if f not in hs0.functions:
continue

# Loop over the functions in the HaloSpots
for f, v in hs1.fmapper.items():
# If no time accesses, skip
if not hs1.halo_scheme.fmapper[f].loc_indices:
continue
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in motion_rules()):
break
else:
# hs1 is lifted out of `it`, and we need to get
# the new indexing for the HaloSpot
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}

for d in hse.loc_indices:
md = hse.loc_indices[d]
if md in it.uindices:
new_min = md.symbolic_min.subs(it.dim,
it.dim.symbolic_min)
raw_loc_indices[d] = new_min
else:
raw_loc_indices[d] = md

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

# If the function is not in both HaloSpots, skip
if f not in hs0.functions:
continue
hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)

for it in iters:
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
else:
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Entering here means we can lift, and we need to update
# the loc_indices with known values
# TODO: Can I get this in a more elegant way?
for d in hse.loc_indices:
md = hse.loc_indices[d]
if md.is_Symbol:
root = md.root
hse_min = md.symbolic_min
new_min = hse_min.subs(root, root.symbolic_min)
raw_loc_indices[d] = new_min
else:
# md is in form of an expression
assert d.symbolic_min in md.free_symbols
raw_loc_indices[d] = md

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[it].append(hs1.halo_scheme.project(f))
imapper[it].append(hs1.halo_scheme.project(f))

mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
for i, hss in imapper.items()}

mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
for i, hs in hsmapper.items()})

iet = Transformer(mapper, nested=True).visit(iet)

return iet

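To make the `loc_indices` rebasing in `_hoist_invariant` concrete: once a HaloSpot is lifted out of `it`, a modulo index defined in terms of the loop variable must be pinned to the value it takes on entry to the loop, which is what the substitution against `it.dim.symbolic_min` achieves. A minimal SymPy sketch, with `time` and `time_m` as hypothetical stand-ins for `it.dim` and its symbolic minimum:

from sympy import Symbol, Mod

time = Symbol('time')      # stand-in for the Iteration dimension `it.dim`
time_m = Symbol('time_m')  # stand-in for `it.dim.symbolic_min`
t0 = Mod(time + 1, 2)      # stand-in for a ModuloDimension's symbolic_min

# Pin the index to its loop-entry value, since the exchange now happens before the loop
new_min = t0.subs(time, time_m)
print(new_min)  # Mod(time_m + 1, 2)
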

def _merge_halospots(iet):
"""
Merge HaloSpots on the same Iteration tree level where all data dependencies
would be honored.
would be honored. This helps to avoid redundant halo exchanges when the same
data would otherwise be exchanged more than once within the same Iteration
tree level, and to initiate multiple halo exchanges at once.
Example:
for time                            for time
    haloupd v[t0]                       haloupd v[t0], h
    W v[t1]- R v[t0]                    W v[t1]- R v[t0]
    haloupd v[t0], h
    W g[t1]- R v[t0], h                 W g[t1]- R v[t0], h
"""

# Analysis
cond_mapper = _make_cond_mapper(iet)

mapper = {}
for iter, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
if iter is None or len(halo_spots) <= 1:
continue

scope = Scope([e.expr for e in FindNodes(Expression).visit(iter)])
iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)

# Drop entries whose key is None
iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}

# Drop iter_mapper pairs where len(halo_spots) <= 1
iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1}

for it, halo_spots in iter_mapper.items():

scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])

hs0 = halo_spots[0]
mapper[hs0] = hs0.halo_scheme

for hs1 in halo_spots[1:]:
mapper[hs1] = hs1.halo_scheme

# If there are Conditionals involved, both `hs0` and `hs1` must be
# within the same Conditional, otherwise we would break the control
# flow semantics
if cond_mapper.get(hs0) != cond_mapper.get(hs1):
if ensure_control_flow(hs0, hs1, cond_mapper):
continue

for f, v in hs1.fmapper.items():
for dep in scope.d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
if not any(r(dep, hs1, v.loc_indices) for r in motion_rules()):
break
else:
# hs1 is merged with hs0
hs = hs1.halo_scheme.project(f)
mapper[hs0] = HaloScheme.union([mapper[hs0], hs])
mapper[hs1] = mapper[hs1].drop(f)
mapper[hs0] = HaloScheme.union([mapper.get(hs0, hs0.halo_scheme), hs])
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)

# Post-process analysis
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
@@ -353,7 +371,7 @@ def mpiize(graph, **kwargs):
make_reductions(graph, mpimode=mpimode, **kwargs)


# Utility functions to avoid code duplication
# *** Utilities

def _make_cond_mapper(iet):
cond_mapper = MapHaloSpots().visit(iet)
@@ -362,26 +380,31 @@ def _make_cond_mapper(iet):
for hs, v in cond_mapper.items()}


def merge_rules():
# Merge rules -- if the retval is True, then it means the input `dep` is not
# a stopper to halo merging
def ensure_control_flow(hs0, hs1, cond_mapper):
"""
# If there are Conditionals involved, both `hs0` and `hs1` must be
# within the same Conditional, otherwise we would break the control
"""
cond0 = cond_mapper.get(hs0)
cond1 = cond_mapper.get(hs1)

return cond0 != cond1


def motion_rules():
# Code motion rules -- if the retval is True, then it means the input `dep` is not
# a stopper to moving the HaloSpot `hs` around

def rule0(dep, hs, loc_indices):
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>` => True
return not any(d in hs.dimensions or dep.distance_mapper[d] is S.Infinity
for d in dep.cause)

def rule1(dep, hs, loc_indices):
# TODO This is apparently never hit, but feeling uncomfortable to remove it
return (dep.is_regular and
dep.read is not None and
all(not any(dep.read.touched_halo(d.root)) for d in dep.cause))

def rule2(dep, hs, loc_indices):
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>` and `loc_indices={t: t0}` => True
return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v
for d, v in loc_indices.items())

rules = [rule0, rule1, rule2]
rules = [rule0, rule1]

return rules
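
Both `_hoist_invariant` and `_merge_halospots` consume these rules through Python's for/else idiom: a HaloSpot is moved only if every flow dependence on `f` satisfies at least one rule, that is, only if the inner loop finishes without hitting `break`. A small self-contained sketch of that idiom, with toy dictionaries standing in for Devito's dependence objects:

def can_move(deps, rules):
    for dep in deps:
        if not any(rule(dep) for rule in rules):
            break  # one dependence blocks the code motion
    else:
        return True  # the `else` of a `for` runs only when no `break` occurred
    return False

# Toy stand-ins (hypothetical) for dependences and motion rules
rules = [lambda dep: dep['distance'] == 0,
         lambda dep: dep['cause'] != 'time']
deps = [{'distance': 0, 'cause': 'x'},
        {'distance': 1, 'cause': 'x'}]
print(can_move(deps, rules))  # True: each dependence is covered by some rule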
