Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Introduce symbolic fencing #2244

Merged
merged 12 commits into from
Oct 25, 2023
56 changes: 43 additions & 13 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from devito.ir.equations import ClusterizedEq
from devito.ir.support import (PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext,
Forward, Interval, IntervalGroup, IterationSpace,
DataSpace, Guards, Properties, Scope, detect_accesses,
detect_io, normalize_properties, normalize_syncs,
minimum, maximum, null_ispace)
DataSpace, Guards, Properties, Scope, WithLock,
PrefetchUpdate, detect_accesses, detect_io,
normalize_properties, normalize_syncs, minimum,
maximum, null_ispace)
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.symbolics import estimate_cost
from devito.tools import as_tuple, flatten, frozendict, infer_dtype
from devito.types import WeakFence, CriticalRegion

__all__ = ["Cluster", "ClusterGroup"]

Expand Down Expand Up @@ -176,10 +178,6 @@ def functions(self):
def has_increments(self):
return any(e.is_Increment for e in self.exprs)

@cached_property
def is_scalar(self):
return not any(f.is_Function for f in self.scope.writes)

@cached_property
def grid(self):
grids = set(f.grid for f in self.functions if f.is_DiscreteFunction) - {None}
Expand All @@ -188,15 +186,21 @@ def grid(self):
else:
raise ValueError("Cluster has no unique Grid")

@cached_property
def is_scalar(self):
return not any(f.is_Function for f in self.scope.writes)

@cached_property
def is_dense(self):
"""
A Cluster is dense if at least one of the following conditions is True:
True if at least one of the following conditions is True:

* It is defined over a unique Grid and all of the Grid Dimensions
are PARALLEL.
* Only DiscreteFunctions are written and only affine index functions
are used (e.g., `a[x+1, y-2]` is OK, while `a[b[x], y-2]` is not)

False in all other cases.
"""
# Hopefully it's got a unique Grid and all Dimensions are PARALLEL (or
# at most PARALLEL_IF_PVT). This is a quick and easy check so we try it first
Expand All @@ -212,21 +216,47 @@ def is_dense(self):
# Fallback to legacy is_dense checks
return (not any(e.conditionals for e in self.exprs) and
not any(f.is_SparseFunction for f in self.functions) and
not self.is_halo_touch and
not self.is_wild and
all(a.is_regular for a in self.scope.accesses))

@cached_property
def is_sparse(self):
"""
A Cluster is sparse if it represents a sparse operation, i.e iff
There's at least one irregular access.
True if it represents a sparse operation, i.e., iff there's at least
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra f on "iff"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if and only if?

one irregular access, False otherwise.
"""
return any(a.is_irregular for a in self.scope.accesses)

@property
def is_wild(self):
    """
    True if encoding a non-mathematical operation, False otherwise.

    A Cluster is wild when it is either a halo touch or a fence.
    """
    # Coerce to a strict bool: `or` returns one of its operands, so a falsy
    # non-bool operand (e.g., an empty container) could otherwise leak out.
    # This matters when the value is used as a grouping key, since such a
    # value does not compare equal to False
    return bool(self.is_halo_touch or self.is_fence)

@property
def is_halo_touch(self):
    """
    True if the Cluster has at least one expression and the right-hand side
    of every expression is a HaloTouch, False otherwise.
    """
    # `bool(...)` rather than a bare `self.exprs and ...`, so that a Cluster
    # without expressions yields False and not the empty container itself
    return (bool(self.exprs) and
            all(isinstance(e.rhs, HaloTouch) for e in self.exprs))

@property
def is_fence(self):
    """
    True if the Cluster is a weak fence or a critical region, False otherwise.
    """
    # `or` returns one of its operands; coerce so that callers always get
    # a strict bool, even if the underlying checks return a falsy non-bool
    return bool(self.is_weak_fence or self.is_critical_region)

@property
def is_weak_fence(self):
    """
    True if the Cluster has at least one expression and the right-hand side
    of every expression is a WeakFence, False otherwise.
    """
    # `bool(...)` rather than a bare `self.exprs and ...`, so that a Cluster
    # without expressions yields False and not the empty container itself
    return (bool(self.exprs) and
            all(isinstance(e.rhs, WeakFence) for e in self.exprs))

@property
def is_critical_region(self):
    """
    True if the Cluster has at least one expression and the right-hand side
    of every expression is a CriticalRegion, False otherwise.
    """
    # `bool(...)` rather than a bare `self.exprs and ...`, so that a Cluster
    # without expressions yields False and not the empty container itself
    return (bool(self.exprs) and
            all(isinstance(e.rhs, CriticalRegion) for e in self.exprs))

@property
def is_async(self):
    """
    True if an asynchronous Cluster, False otherwise.

    A Cluster is reported as asynchronous as soon as one of its
    synchronization operations is a WithLock or a PrefetchUpdate.
    """
    # `self.syncs` maps to collections of sync operations; scan them all
    for sync in flatten(self.syncs.values()):
        if isinstance(sync, (WithLock, PrefetchUpdate)):
            return True
    return False

@cached_property
def dtype(self):
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/clusters/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ def __init__(self, func, mode='dense'):
self.func = func

if mode == 'dense':
self.cond = lambda c: c.is_dense or not c.is_sparse
self.cond = lambda c: (c.is_dense or not c.is_sparse) and not c.is_wild
elif mode == 'sparse':
self.cond = lambda c: c.is_sparse
self.cond = lambda c: c.is_sparse and not c.is_wild
else:
self.cond = lambda c: True

Expand Down
18 changes: 16 additions & 2 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,29 @@ def stree_build(clusters, profiler=None, **kwargs):

def preprocess(clusters, options=None, **kwargs):
"""
Remove the HaloTouch's from `clusters` and create a mapping associating
each removed HaloTouch to the first Cluster necessitating it.
Lower the so-called "wild" Clusters, that is objects not representing a set
of mathematical operations. This boils down to:

* Moving the HaloTouch's from `clusters` into a mapper `M: {HT -> C}`.
`c = M(ht)` is the first Cluster of the sequence requiring the halo
exchange `ht` to have terminated before the execution can proceed.
* Lower the CriticalRegions:
* If they encode an asynchronous operation (e.g., a WaitLock), attach
it to a Nop Cluster for future lowering;
* Otherwise, simply remove them, as they have served their purpose
at this point.
* Remove the WeakFences, as they have served their purpose at this point.
"""
queue = []
processed = []
for c in clusters:
if c.is_halo_touch:
hs = HaloScheme.union(e.rhs.halo_scheme for e in c.exprs)
queue.append(c.rebuild(halo_scheme=hs))
elif c.is_critical_region and c.syncs:
processed.append(c.rebuild(exprs=None, guards=c.guards, syncs=c.syncs))
elif c.is_wild:
continue
else:
dims = set(c.ispace.promote(lambda d: d.is_Block).itdims)

Expand Down
60 changes: 47 additions & 13 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
q_constant, q_affine, q_routine, search, uxreplace)
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
flatten, memoized_meth, memoized_generator)
from devito.types import (Barrier, ComponentAccess, Dimension, DimensionTuple,
Function, Jump, Symbol, Temp, TempArray, TBArray)
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
CriticalRegion, Function, Symbol, Temp, TempArray,
TBArray)

__all__ = ['IterationInstance', 'TimedAccess', 'Scope', 'ExprGeometry']

Expand All @@ -23,10 +24,9 @@ class IndexMode(Tag):
REGULAR = IndexMode('regular')
IRREGULAR = IndexMode('irregular')

mocksym = Symbol(name='⋈')
"""
A Symbol to create mock data depdendencies.
"""
# Symbols to create mock data dependencies
mocksym0 = Symbol(name='__⋈_0__')
mocksym1 = Symbol(name='__⋈_1__')


class IterationInstance(LabeledVector):
Expand Down Expand Up @@ -848,9 +848,21 @@ def writes_gen(self):

# Objects altering the control flow (e.g., synchronization barriers,
# break statements, ...) are converted into mock dependences

# Fences (any sort) cannot float around upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, (Barrier, Jump)):
yield TimedAccess(mocksym, 'W', i, e.ispace)
if isinstance(e.rhs, Fence):
yield TimedAccess(mocksym0, 'W', i, e.ispace)

# CriticalRegions are stronger than plain Fences.
# We must also ensure that none of the Eqs within an opening-closing
# CriticalRegion pair floats outside upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, CriticalRegion) and e.rhs.opening:
for j, e1 in enumerate(self.exprs[i+1:], 1):
if isinstance(e1.rhs, CriticalRegion) and e1.rhs.closing:
break
yield TimedAccess(mocksym1, 'W', i+j, e1.ispace)

@cached_property
def writes(self):
Expand Down Expand Up @@ -904,12 +916,32 @@ def reads_implicit_gen(self):
for i in symbols:
yield TimedAccess(i, 'R', -1)

@memoized_generator
def reads_synchro_gen(self):
"""
Generate all reads due to synchronization operations. These may be explicit
or implicit.
"""
# Objects altering the control flow (e.g., synchronization barriers,
# break statements, ...) are converted into mock dependences

# Fences (any sort) cannot float around upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, Fence):
if i > 0:
yield TimedAccess(mocksym0, 'R', i-1, e.ispace)
if i < len(self.exprs)-1:
yield TimedAccess(mocksym0, 'R', i+1, e.ispace)

# CriticalRegions are stronger than plain Fences.
# We must also ensure that none of the Eqs within an opening-closing
# CriticalRegion pair floats outside upon topological sorting
for i, e in enumerate(self.exprs):
if isinstance(e.rhs, (Barrier, Jump)):
yield TimedAccess(mocksym, 'R', max(i, 0), e.ispace)
yield TimedAccess(mocksym, 'R', i+1, e.ispace)
if isinstance(e.rhs, CriticalRegion):
if e.rhs.opening and i > 0:
yield TimedAccess(mocksym1, 'R', i-1, self.exprs[i-1].ispace)
elif e.rhs.closing and i < len(self.exprs)-1:
yield TimedAccess(mocksym1, 'R', i+1, self.exprs[i+1].ispace)

@memoized_generator
def reads_gen(self):
Expand All @@ -920,7 +952,9 @@ def reads_gen(self):
# is efficiency. Sometimes we wish to extract all reads to a given
# AbstractFunction, and we know that by construction these can't
# appear among the implicit reads
return chain(self.reads_explicit_gen(), self.reads_implicit_gen())
return chain(self.reads_explicit_gen(),
self.reads_synchro_gen(),
self.reads_implicit_gen())

@memoized_generator
def reads_smart_gen(self, f):
Expand All @@ -939,7 +973,7 @@ def reads_smart_gen(self, f):
the iteration symbols.
"""
if isinstance(f, (Function, Temp, TempArray, TBArray)):
for i in self.reads_explicit_gen():
for i in chain(self.reads_explicit_gen(), self.reads_synchro_gen()):
if f is i.function:
for j in extrema(i.access):
yield TimedAccess(j, i.mode, i.timestamp, i.ispace)
Expand Down
8 changes: 7 additions & 1 deletion devito/passes/clusters/asynchrony.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from devito.ir import (Forward, GuardBoundNext, Queue, Vector, WaitLock, WithLock,
FetchUpdate, PrefetchUpdate, ReleaseLock, normalize_syncs)
from devito.passes.clusters.utils import is_memcpy
from devito.passes.clusters.utils import bind_critical_regions, is_memcpy
from devito.symbolics import IntDiv, uxreplace
from devito.tools import OrderedSet, is_integer, timed_pass
from devito.types import CustomDimension, Lock
Expand Down Expand Up @@ -139,6 +139,12 @@ def callback(self, clusters, prefix):
tasks[c0].append(ReleaseLock(lock[i], target))
tasks[c0].append(WithLock(lock[i], target, i, function, findex, d))

# CriticalRegions preempt WaitLocks, by definition
mapper = bind_critical_regions(clusters)
for c in clusters:
for c1 in mapper.get(c, []):
waits[c].update(waits.pop(c1, []))

processed = []
for c in clusters:
if waits[c] or tasks[c]:
Expand Down
8 changes: 5 additions & 3 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass
from devito.ir.support import (SEQUENTIAL, SEPARABLE, Scope, ReleaseLock,
WaitLock, WithLock, FetchUpdate, PrefetchUpdate)
from devito.passes.clusters.utils import in_critical_region
from devito.symbolics import pow_to_mul
from devito.tools import DAG, Stamp, as_tuple, flatten, frozendict, timed_pass
from devito.types import Hyperplane
Expand Down Expand Up @@ -44,8 +45,9 @@ def callback(self, clusters, prefix):
processed.append(c)
continue

# Synchronization operations prevent lifting
if c.syncs.get(dim):
# Synchronization prevents lifting
if c.syncs.get(dim) or \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this get too long as one line?

Copy link
Contributor Author

@FabioLuporini FabioLuporini Oct 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do the splitting on purpose for clarity when I feel there's (relatively speaking) little correlation between the two operands of the relation

in_critical_region(c, clusters):
processed.append(c)
continue

Expand Down Expand Up @@ -262,7 +264,7 @@ def dump():

groups, processed = processed, []
for group in groups:
for flag, minigroup in groupby(group, key=lambda c: c.is_halo_touch):
for flag, minigroup in groupby(group, key=lambda c: c.is_wild):
if flag:
processed.extend([(c,) for c in minigroup])
else:
Expand Down
50 changes: 48 additions & 2 deletions devito/passes/clusters/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from collections import defaultdict

from devito.ir import Cluster
from devito.symbolics import uxreplace
from devito.types import Symbol, Wildcard
from devito.tools import as_tuple, flatten
from devito.types import CriticalRegion, Eq, Symbol, Wildcard

__all__ = ['makeit_ssa', 'is_memcpy']
__all__ = ['makeit_ssa', 'is_memcpy', 'make_critical_sequence',
'bind_critical_regions', 'in_critical_region']


def makeit_ssa(exprs):
Expand Down Expand Up @@ -48,3 +53,44 @@ def is_memcpy(expr):
return False

return a.function.is_Array or b.function.is_Array


def make_critical_sequence(ispace, sequence, **kwargs):
    """
    Enclose `sequence` between an opening and a closing CriticalRegion.

    Parameters
    ----------
    ispace
        Passed as `ispace` to both of the delimiting Clusters.
    sequence
        One or more items to be enclosed; normalized through `as_tuple`.
        It must contain at least one item.
    **kwargs
        Forwarded unchanged to the constructor of both delimiting Clusters.

    Returns
    -------
    list
        The opening Cluster, then the items of `sequence` in their original
        order, then the closing Cluster.
    """
    sequence = as_tuple(sequence)
    assert len(sequence) >= 1

    def make_boundary(opening):
        # A mock equation whose right-hand side marks the region boundary;
        # `True` builds the opening marker, `False` the closing one
        expr = Eq(Symbol(name='⋈'), CriticalRegion(opening))
        return Cluster(exprs=expr, ispace=ispace, **kwargs)

    return [make_boundary(True), *sequence, make_boundary(False)]


def bind_critical_regions(clusters):
    """
    Map each opening CriticalRegion to the Clusters of its critical sequence.

    The Clusters are scanned in order. CriticalRegion Clusters act as a
    toggle: the first one encountered opens a region, the next one closes it,
    and so on. Any other Cluster found while a region is open is appended to
    the list bound to the Cluster that opened that region.

    Returns
    -------
    defaultdict
        `{opening Cluster -> [enclosed Clusters]}`. Regions enclosing no
        Clusters get no entry.
    """
    mapper = defaultdict(list)

    # The Cluster that opened the region being scanned, if any
    opener = None

    for c in clusters:
        if c.is_critical_region:
            # Toggle: close the region if one is open, otherwise open one
            opener = None if opener else c
        elif opener:
            mapper[opener].append(c)

    return mapper


def in_critical_region(cluster, clusters):
    """
    True if `cluster` is part of a critical sequence, False otherwise.

    The critical sequences are those that `bind_critical_regions` finds
    in `clusters`.
    """
    sequences = bind_critical_regions(clusters).values()
    members = flatten(sequences)
    return cluster in members
Loading