compiler: Rework terminology and minor cleanup
georgebisbas committed Nov 7, 2024
1 parent 3531563 commit 0d2a884
Showing 3 changed files with 90 additions and 83 deletions.
4 changes: 2 additions & 2 deletions devito/mpi/halo_scheme.py
@@ -677,8 +677,8 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
# Nope, let's try with the next Indexed, if any
continue

hse = hse0.rebuild(loc_indices=frozendict(loc_indices),
loc_dirs=frozendict(loc_dirs))
hse = hse0._rebuild(loc_indices=frozendict(loc_indices),
loc_dirs=frozendict(loc_dirs))

else:
continue
159 changes: 85 additions & 74 deletions devito/passes/iet/mpi.py
@@ -19,19 +19,19 @@
@iet_pass
def optimize_halospots(iet, **kwargs):
"""
Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, merged and moved
around in order to improve the halo exchange performance.
Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, hoisted,
merged and moved around in order to improve the halo exchange performance.
"""
iet = _drop_halospots(iet)
iet = _hoist_halospots(iet)
iet = _drop_reduction_halospots(iet)
iet = _hoist_invariant(iet)
iet = _merge_halospots(iet)
iet = _drop_if_unwritten(iet, **kwargs)
iet = _mark_overlappable(iet)

return iet, {}
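
Editorial aside: the sub-passes listed above share a common shape, namely collect the HaloSpots with a visitor, compute new HaloSchemes for them, then substitute the rebuilt nodes back into the tree in one Transformer sweep. A minimal sketch of that pattern follows; the import path and the pass name are assumptions, not code from this commit.

from devito.ir.iet import FindNodes, HaloSpot, Transformer  # assumed import path

def _some_halospot_pass(iet):
    # Collect every HaloSpot, decide how its HaloScheme should change, and
    # substitute the rebuilt nodes back into the tree in a single sweep.
    mapper = {}
    for hs in FindNodes(HaloSpot).visit(iet):
        new_scheme = hs.halo_scheme  # ...the real passes shrink/move this
        mapper[hs] = hs._rebuild(halo_scheme=new_scheme)
    return Transformer(mapper, nested=True).visit(iet)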


def _drop_halospots(iet):
def _drop_reduction_halospots(iet):
"""
Remove HaloSpots that:
@@ -48,17 +48,19 @@ def _drop_halospots(iet):
mapper[hs].add(f)

# Transform the IET introducing the "reduced" HaloSpots
subs = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs]))
for hs in FindNodes(HaloSpot).visit(iet)}
iet = Transformer(subs, nested=True).visit(iet)
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs]))
for hs in FindNodes(HaloSpot).visit(iet)}
iet = Transformer(mapper, nested=True).visit(iet)

return iet


def _hoist_halospots(iet):
def _hoist_invariant(iet):
"""
Hoist HaloSpots from inner to outer Iterations where all data dependencies
would be honored.
would be honored. This pass is particularly useful to avoid redundant
halo exchanges when the same data would otherwise be exchanged multiple
times within the same Iteration tree level.
Example:
haloupd v[t0]
@@ -80,18 +82,23 @@
hsmapper = {}
imapper = defaultdict(list)

# Look for parent Iterations of children HaloSpots
for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items():
for i, hs0 in enumerate(halo_spots):
nodes = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)

# Nothing to do if the HaloSpot is void
if hs0.halo_scheme.is_void:
continue
# Drop void `halo_scheme`s from the analysis
iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void]
for k, v in nodes.items()}

# Drop entries whose key is None
iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}

for it, halo_spots in iter_mapper.items():
if len(halo_spots) <= 1:
continue

for i, hs0 in enumerate(halo_spots):
for hs1 in halo_spots[i+1:]:
# If there are Conditionals involved, both `hs0` and `hs1` must be
# within the same Conditional, otherwise we would break the control
if cond_mapper.get(hs0) != cond_mapper.get(hs1):

if not ensure_control_flow(hs0, hs1, cond_mapper):
continue

# If there are overlapping time accesses, skip
@@ -100,88 +107,87 @@

# Loop over the functions in the HaloSpots
for f, v in hs1.fmapper.items():
# If no time accesses, skip
if not hs1.halo_scheme.fmapper[f].loc_indices:
continue

# If the function is not in both HaloSpots, skip
if f not in hs0.functions:
continue

for it in iters:
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
else:
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Entering here means we can lift, and we need to update
# the loc_indices with known values
# TODO: Can I get this in a more elegant way?
for d in hse.loc_indices:
md = hse.loc_indices[d]
if md.is_Symbol:
root = md.root
hse_min = md.symbolic_min
new_min = hse_min.subs(root, root.symbolic_min)
raw_loc_indices[d] = new_min
else:
# md is in form of an expression
assert d.symbolic_min in md.free_symbols
raw_loc_indices[d] = md

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[it].append(hs1.halo_scheme.project(f))
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in motion_rules()):
break
else:
# hs1 is lifted out of `it`
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Get the new indexing for the lifted HaloSpot
for d in hse.loc_indices:
md = hse.loc_indices[d]
if md in it.uindices:
new_min = md.symbolic_min.subs(it.dim,
it.dim.symbolic_min)
raw_loc_indices[d] = new_min
else:
raw_loc_indices[d] = md

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[it].append(hs1.halo_scheme.project(f))

mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
for i, hss in imapper.items()}
mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
for i, hs in hsmapper.items()})

iet = Transformer(mapper, nested=True).visit(iet)

return iet
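
Editorial aside: the loc_indices rebinding above pins a modulo-buffered index to the value it takes at the entry of the Iteration being hoisted over. A minimal, hypothetical sympy sketch of that substitution (an illustration of the idea, not the Devito types used in the diff):

from sympy import Symbol, Mod

t = Symbol('t')        # the Iteration dimension being hoisted over
t_m = Symbol('t_m')    # its symbolic minimum, i.e. the value at loop entry
t0 = Mod(t, 2)         # a modulo-buffered time index, as in `v[t0]`

pinned = t0.subs(t, t_m)
print(pinned)          # Mod(t_m, 2) -- the index the hoisted haloupd will use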


def _merge_halospots(iet):
"""
Merge HaloSpots on the same Iteration tree level where all data dependencies
would be honored.
would be honored. This helps avoid redundant halo exchanges when the same data
would otherwise be exchanged multiple times at the same Iteration tree level,
and also allows several halo exchanges to be initiated at once.
Example:
for time                            for time
  haloupd v[t0]                       haloupd v[t0], h
  W v[t1]- R v[t0]                    W v[t1]- R v[t0]
  haloupd v[t0], h
  W g[t1]- R v[t0], h                 W g[t1]- R v[t0], h
"""

# Analysis
cond_mapper = _make_cond_mapper(iet)

mapper = {}
for iter, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
if iter is None or len(halo_spots) <= 1:
for it, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
if it is None or len(halo_spots) <= 1:
continue

scope = Scope([e.expr for e in FindNodes(Expression).visit(iter)])
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])

hs0 = halo_spots[0]
mapper[hs0] = hs0.halo_scheme

for hs1 in halo_spots[1:]:
mapper[hs1] = hs1.halo_scheme

# If there are Conditionals involved, both `hs0` and `hs1` must be
# within the same Conditional, otherwise we would break the control
# flow semantics
if cond_mapper.get(hs0) != cond_mapper.get(hs1):
if not ensure_control_flow(hs0, hs1, cond_mapper):
continue

for f, v in hs1.fmapper.items():
for dep in scope.d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
if not any(r(dep, hs1, v.loc_indices) for r in motion_rules()):
break
else:
# hs1 is lifted to hs0 level and merged with hs0
hs = hs1.halo_scheme.project(f)
mapper[hs0] = HaloScheme.union([mapper[hs0], hs])
mapper[hs1] = mapper[hs1].drop(f)

mapper[hs0] = HaloScheme.union([mapper.get(hs0, hs0.halo_scheme), hs])
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)

# Post-process analysis
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
@@ -353,7 +359,7 @@ def mpiize(graph, **kwargs):
make_reductions(graph, mpimode=mpimode, **kwargs)


# Utility functions to avoid code duplication
# *** Utilities

def _make_cond_mapper(iet):
cond_mapper = MapHaloSpots().visit(iet)
@@ -362,26 +368,31 @@ def _make_cond_mapper(iet):
for hs, v in cond_mapper.items()}


def merge_rules():
# Merge rules -- if the retval is True, then it means the input `dep` is not
# a stopper to halo merging
def ensure_control_flow(hs0, hs1, cond_mapper):
"""
If there are Conditionals involved, both `hs0` and `hs1` must be within
the same Conditional, otherwise we would break the control flow semantics.
"""
cond0 = cond_mapper.get(hs0)
cond1 = cond_mapper.get(hs1)

return cond0 != cond1


def motion_rules():
# Code motion rules -- if the retval is True, then it means the input `dep` is not
# a stopper to moving the HaloSpot `hs` around

def rule0(dep, hs, loc_indices):
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>` => True
return not any(d in hs.dimensions or dep.distance_mapper[d] is S.Infinity
for d in dep.cause)

def rule1(dep, hs, loc_indices):
# TODO This is apparently never hit, but feeling uncomfortable to remove it
return (dep.is_regular and
dep.read is not None and
all(not any(dep.read.touched_halo(d.root)) for d in dep.cause))

def rule2(dep, hs, loc_indices):
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>` and `loc_indices={t: t0}` => True
return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v
for d, v in loc_indices.items())

rules = [rule0, rule1, rule2]
rules = [rule0, rule1]

return rules
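
Editorial aside: in both callers above, the for/else idiom over the flow dependences reduces to the following predicate; a HaloSpot entry for `f` may be moved only if every dependence is cleared by at least one rule. This is a sketch of the control flow, not code from the diff.

def can_move(deps, hs, loc_indices, rules):
    # True iff no dependence on `f` is a stopper, i.e. every `dep` satisfies
    # at least one of the motion rules.
    return all(any(rule(dep, hs, loc_indices) for rule in rules) for dep in deps)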
10 changes: 3 additions & 7 deletions tests/test_mpi.py
@@ -5,11 +5,10 @@
from conftest import _R, assert_blocking, assert_structure
from devito import (Grid, Constant, Function, TimeFunction, SparseFunction,
SparseTimeFunction, VectorTimeFunction, TensorTimeFunction,
Dimension, ConditionalDimension, div,
Dimension, ConditionalDimension, div, solve, diag, grad,
SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm,
inner, configuration, switchconfig, generic_derivative,
PrecomputedSparseFunction, DefaultDimension, Buffer,
solve, diag, grad)
PrecomputedSparseFunction, DefaultDimension, Buffer)
from devito.arch.compiler import OneapiCompiler
from devito.data import LEFT, RIGHT
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
Expand All @@ -19,7 +18,6 @@
ComputeCall)
from devito.mpi.distributed import CustomTopology
from devito.tools import Bunch
from devito.types.dimension import SpaceDimension

from examples.seismic.acoustic import acoustic_setup
from examples.seismic import demo_model
@@ -1017,9 +1015,7 @@ def test_issue_2448(self, mode):
shape = (2,)
so = 2

x = SpaceDimension(name='x', spacing=Constant(name='h_x',
value=extent[0]/(shape[0]-1)))
grid = Grid(extent=extent, shape=shape, dimensions=(x,))
grid = Grid(extent=extent, shape=shape)

# Time related
tn = 30
