Skip to content

Commit

Permalink
Merge pull request #1235 from devitocodes/fix-aliases-scheduling
Browse files Browse the repository at this point in the history
Fix scheduling of CIRE-detected aliasing expressions
  • Loading branch information
mloubout authored Apr 16, 2020
2 parents a5b05e6 + 6be23ec commit 0ee4fac
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 115 deletions.
5 changes: 1 addition & 4 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from devito.exceptions import InvalidOperator
from devito.ir.clusters import Toposort
from devito.passes.clusters import (Blocking, Lift, cire, cse, eliminate_arrays,
extract_increments, factorize, fuse, optimize_pows,
scalarize)
extract_increments, factorize, fuse, optimize_pows)
from devito.passes.iet import (DataManager, Ompizer, avoid_denormals, mpiize,
optimize_halospots, loop_wrapping, hoist_prodders,
relax_incr_dimensions)
Expand Down Expand Up @@ -111,7 +110,6 @@ def _specialize_clusters(cls, clusters, **kwargs):
# turn may enable further optimizations
clusters = fuse(clusters)
clusters = eliminate_arrays(clusters, template)
clusters = scalarize(clusters, template)

return clusters

Expand Down Expand Up @@ -225,7 +223,6 @@ def _specialize_clusters(cls, clusters, **kwargs):
# turn may enable further optimizations
clusters = fuse(clusters)
clusters = eliminate_arrays(clusters, template)
clusters = scalarize(clusters, template)

# Blocking to improve data locality
clusters = Blocking(options).process(clusters)
Expand Down
3 changes: 1 addition & 2 deletions devito/core/gpu_openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from devito.logger import warning
from devito.mpi.routines import CopyBuffer, SendRecv, HaloUpdate
from devito.passes.clusters import (Lift, cire, cse, eliminate_arrays, extract_increments,
factorize, fuse, optimize_pows, scalarize)
factorize, fuse, optimize_pows)
from devito.passes.iet import (DataManager, Storage, Ompizer, ParallelIteration,
ParallelTree, optimize_halospots, mpiize, hoist_prodders,
iet_pass)
Expand Down Expand Up @@ -278,7 +278,6 @@ def _specialize_clusters(cls, clusters, **kwargs):
# further optimizations
clusters = fuse(clusters)
clusters = eliminate_arrays(clusters, template)
clusters = scalarize(clusters, template)

return clusters

Expand Down
31 changes: 31 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def _has(self, pattern):

class DifferentiableOp(Differentiable):

__sympy_class__ = None

def __new__(cls, *args, **kwargs):
obj = cls.__base__.__new__(cls, *args, **kwargs)

Expand Down Expand Up @@ -294,18 +296,22 @@ def _eval_is_zero(self):


class Add(DifferentiableOp, sympy.Add):
    # Plain SymPy class that `diff2sympy` instantiates in place of this one
    # when lowering a Differentiable expression
    __sympy_class__ = sympy.Add
    # Bind the DifferentiableOp constructor explicitly on this class
    __new__ = DifferentiableOp.__new__


class Mul(DifferentiableOp, sympy.Mul):
    # Plain SymPy class that `diff2sympy` instantiates in place of this one
    # when lowering a Differentiable expression
    __sympy_class__ = sympy.Mul
    # Bind the DifferentiableOp constructor explicitly on this class
    __new__ = DifferentiableOp.__new__


class Pow(DifferentiableOp, sympy.Pow):
    # Plain SymPy class that `diff2sympy` instantiates in place of this one
    # when lowering a Differentiable expression
    __sympy_class__ = sympy.Pow
    # Bind the DifferentiableOp constructor explicitly on this class
    __new__ = DifferentiableOp.__new__


class Mod(DifferentiableOp, sympy.Mod):
    # Plain SymPy class that `diff2sympy` instantiates in place of this one
    # when lowering a Differentiable expression
    __sympy_class__ = sympy.Mod
    # Bind the DifferentiableOp constructor explicitly on this class
    __new__ = DifferentiableOp.__new__


Expand Down Expand Up @@ -369,6 +375,31 @@ def _(obj):
return obj.__class__


def diff2sympy(expr):
    """
    Translate a Differentiable expression into a SymPy expression.

    The expression tree is visited bottom-up. Any node exposing a non-None
    ``__sympy_class__`` attribute is re-created as an instance of that class
    from its (already translated) arguments. Any other node is rebuilt through
    ``obj.func`` only if at least one of its arguments was translated;
    otherwise it is returned untouched.

    Parameters
    ----------
    expr : expr-like
        The expression to translate. Every node must expose ``args``; nodes
        that get rebuilt must also expose ``func``.

    Returns
    -------
    expr-like
        The translated expression, or ``expr`` itself if nothing was lowered.
    """

    def _diff2sympy(obj):
        # Translate the children first, tracking whether any of them changed
        lowered = [_diff2sympy(a) for a in obj.args]
        args = [ax for ax, _ in lowered]
        flag = any(af for _, af in lowered)

        # Look the attribute up explicitly rather than wrapping the whole
        # constructor call in `except AttributeError`: that would also swallow
        # unrelated AttributeErrors raised while building the new object.
        # A None value means "no SymPy counterpart", so the node is handled
        # like any other non-DifferentiableOp object
        sympy_class = getattr(obj, '__sympy_class__', None)
        if sympy_class is not None:
            return sympy_class(*args, evaluate=False), True

        if not flag:
            # Nothing changed underneath: hand back the very same object
            return obj, False

        try:
            return obj.func(*args, evaluate=False), True
        except TypeError:
            # Not every `func` accepts the `evaluate` keyword
            return obj.func(*args), True

    return _diff2sympy(expr)[0]


# Make sure `sympy.evalf` knows how to evaluate the inherited classes
# Without these, `evalf` would rely on a much slower, much more generic, and
# thus much more time-inefficient fallback routine. This would hit us
Expand Down
6 changes: 5 additions & 1 deletion devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sympy

from devito.ir.equations.algorithms import dimension_sort
from devito.finite_differences.differentiable import diff2sympy
from devito.ir.support import (IterationSpace, DataSpace, Interval, IntervalGroup,
Stencil, detect_accesses, detect_oobs, detect_io,
build_intervals, build_iterators)
Expand Down Expand Up @@ -147,8 +148,11 @@ def __new__(cls, *args, **kwargs):
for k, v in mapper.items() if k}
dspace = DataSpace(dintervals, parts)

# Lower all Differentiable operations into SymPy operations
rhs = diff2sympy(expr.rhs)

# Finally create the LoweredEq with all metadata attached
expr = super(LoweredEq, cls).__new__(cls, expr.lhs, expr.rhs, evaluate=False)
expr = super(LoweredEq, cls).__new__(cls, expr.lhs, rhs, evaluate=False)

expr._dspace = dspace
expr._ispace = ispace
Expand Down
17 changes: 10 additions & 7 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,10 @@ def zero(self):
def flip(self):
    """Return a new Interval over the same Dimension with the two bounds swapped."""
    swapped_lower, swapped_upper = self.upper, self.lower
    return Interval(self.dim, swapped_lower, swapped_upper, self.stamp)

def lift(self):
return Interval(self.dim, self.lower, self.upper, self.stamp + 1)
def lift(self, v=None):
    """
    Return a new Interval with the same Dimension and bounds but a different
    stamp: ``v`` if given, otherwise the current stamp incremented by one.
    """
    stamp = self.stamp + 1 if v is None else v
    return Interval(self.dim, self.lower, self.upper, stamp)

def reset(self):
    """Return a new Interval with the same Dimension and bounds and stamp 0."""
    dim, lower, upper = self.dim, self.lower, self.upper
    return Interval(dim, lower, upper, 0)
Expand Down Expand Up @@ -373,9 +375,9 @@ def zero(self, d=None):
return IntervalGroup([i.zero() if i.dim in d else i for i in self],
relations=self.relations)

def lift(self, d):
def lift(self, d, v=None):
d = set(self.dimensions if d is None else as_tuple(d))
return IntervalGroup([i.lift() if i.dim._defines & d else i for i in self],
return IntervalGroup([i.lift(v) if i.dim._defines & d else i for i in self],
relations=self.relations)

def reset(self):
Expand Down Expand Up @@ -700,13 +702,14 @@ def project(self, cond):
func = lambda i: i in cond

intervals = [i for i in self.intervals if func(i.dim)]

sub_iterators = {k: v for k, v in self.sub_iterators.items() if func(k)}

directions = {k: v for k, v in self.directions.items() if func(k)}

return IterationSpace(intervals, sub_iterators, directions)

def lift(self, d=None, v=None):
    """
    Return a new IterationSpace whose intervals are ``self.intervals.lift(d, v)``;
    the sub-iterators and directions are carried over unchanged.
    """
    return IterationSpace(self.intervals.lift(d, v), self.sub_iterators,
                          self.directions)

def is_compatible(self, other):
"""
A relaxed version of ``__eq__``, in which only non-derived dimensions
Expand Down
79 changes: 43 additions & 36 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from cached_property import cached_property
import numpy as np

from devito.ir import (ROUNDABLE, DataSpace, IterationInstance, Interval,
IntervalGroup, LabeledVector, detect_accesses, build_intervals)
from devito.ir import (ROUNDABLE, DataSpace, IterationInstance, Interval, IntervalGroup,
LabeledVector, Scope, detect_accesses, build_intervals)
from devito.passes.clusters.utils import cluster_pass, make_is_time_invariant
from devito.symbolics import (compare_ops, estimate_cost, q_constant, q_leaf,
q_sum_of_product, q_terminalop, retrieve_indexed,
Expand Down Expand Up @@ -88,42 +88,50 @@ def cire(cluster, template, mode, options, platform):
min_storage = options['min-storage']

# Setup callbacks
if mode == 'invariants':
# Extraction rule
def extractor(context):
return make_is_time_invariant(context)

# Extraction model
def callbacks_invariants(context, *args):
extractor = make_is_time_invariant(context)
model = lambda e: estimate_cost(e, True) >= MIN_COST_ALIAS_INV

# Collection rule
ignore_collected = lambda g: False

# Selection rule
selector = lambda c, n: c >= MIN_COST_ALIAS_INV and n >= 1

elif mode == 'sops':
# Extraction rule
def extractor(context):
return q_sum_of_product

# Extraction model
model = lambda e: not (q_leaf(e) or q_terminalop(e))

# Collection rule
return extractor, model, ignore_collected, selector

def callbacks_sops(context, n):
# The `depth` determines "how big" the extracted sum-of-products will be.
# We observe that in typical FD codes:
# add(mul, mul, ...) -> stems from first order derivative
# add(mul(add(mul, mul, ...), ...), ...) -> stems from second order derivative
# To catch the former, we would need `depth=1`; for the latter, `depth=3`
depth = 2*n + 1

extractor = lambda e: q_sum_of_product(e, depth)
model = lambda e: not (q_leaf(e) or q_terminalop(e, depth-1))
ignore_collected = lambda g: len(g) <= 1

# Selection rule
selector = lambda c, n: c >= MIN_COST_ALIAS and n > 1
return extractor, model, ignore_collected, selector

callbacks_mapper = {
'invariants': callbacks_invariants,
'sops': callbacks_sops
}

# Actual CIRE
# The main CIRE loop
processed = []
context = cluster.exprs
for _ in range(options['cire-repeats'][mode]):
for n in reversed(range(options['cire-repeats'][mode])):
# Get the callbacks
extractor, model, ignore_collected, selector = callbacks_mapper[mode](context, n)

# Extract potentially aliasing expressions
exprs, extracted = extract(cluster, extractor(context), model, template)
exprs, extracted = extract(cluster, extractor, model, template)
if not extracted:
# Do not waste time
continue

# There can't be Dimension-dependent data dependences with any of
# the `processed` Clusters, otherwise we would risk either OOB accesses
# or reading from garbage uncomputed halo
scope = Scope(exprs=flatten(c.exprs for c in processed) + extracted)
if not all(i.is_indep() for i in scope.d_all_gen()):
break

# Search aliasing expressions
Expand All @@ -133,7 +141,7 @@ def extractor(context):
chosen, others = choose(exprs, aliases, selector)
if not chosen:
# Do not waste time
break
continue

# Create Aliases and assign them to Clusters
clusters, subs = process(cluster, chosen, aliases, template, platform)
Expand Down Expand Up @@ -341,7 +349,7 @@ def choose(exprs, aliases, selector):
def process(cluster, chosen, aliases, template, platform):
clusters = []
subs = {}
for alias, writeto, aliaseds, distances in aliases.schedule(cluster.ispace):
for alias, writeto, aliaseds, distances in aliases.iter(cluster.ispace):
if all(i not in chosen for i in aliaseds):
continue

Expand Down Expand Up @@ -412,7 +420,7 @@ def process(cluster, chosen, aliases, template, platform):

# Finally, build a new Cluster for `alias`
built = cluster.rebuild(exprs=expression, ispace=ispace, dspace=dspace)
clusters.insert(0, built)
clusters.append(built)

return clusters, subs

Expand All @@ -433,6 +441,9 @@ def rebuild(cluster, others, aliases, subs):
return cluster.rebuild(exprs=exprs, ispace=ispace, dspace=dspace)


# Utilities


class Candidate(object):

def __init__(self, expr, indexeds, bases, offsets):
Expand Down Expand Up @@ -488,9 +499,6 @@ def dimensions(self):
return frozenset(i for i, _ in self.Toffsets)


# Utilities


class Group(tuple):

"""
Expand Down Expand Up @@ -688,7 +696,7 @@ def get(self, key):
return aliaseds
return []

def schedule(self, ispace):
def iter(self, ispace):
"""
The aliases can be scheduled in any order, but we privilege the one
that minimizes storage while maximizing fusion.
Expand All @@ -710,8 +718,7 @@ def schedule(self, ispace):
# use `<1>` which is the actual stamp used in the Cluster
# from which the aliasing expressions were extracted
assert i.stamp >= interval.stamp
while interval.stamp != i.stamp:
interval = interval.lift()
interval = interval.lift(i.stamp)

writeto.append(interval)

Expand Down
4 changes: 2 additions & 2 deletions devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def collect_const(expr):

terms = []
for k, v in inverse_mapper.items():
if len(v) == 1:
# We can actually evaluate everything to avoid, e.g., (-1)*a
if len(v) == 1 and not v[0].is_Add:
# Special case: avoid e.g. (-2)*a
mul = Mul(k, *v)
elif all(i.is_Mul and len(i.args) == 2 and i.args[0] == -1 for i in v):
# Other special case: [-a, -b, -c ...]
Expand Down
47 changes: 2 additions & 45 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from devito.ir.clusters import Cluster, Queue
from devito.ir.support import TILABLE
from devito.passes.clusters.utils import cluster_pass
from devito.symbolics import pow_to_mul, xreplace_indices, uxreplace
from devito.symbolics import pow_to_mul, uxreplace
from devito.tools import filter_ordered, timed_pass
from devito.types import Scalar

__all__ = ['Lift', 'fuse', 'scalarize', 'eliminate_arrays', 'optimize_pows',
'extract_increments']
__all__ = ['Lift', 'fuse', 'eliminate_arrays', 'optimize_pows', 'extract_increments']


class Lift(Queue):
Expand Down Expand Up @@ -99,48 +98,6 @@ def fuse(clusters):
return processed


@timed_pass()
def scalarize(clusters, template):
    """
    Turn local "isolated" Arrays, that is Arrays appearing only in one Cluster,
    into Scalars.

    Parameters
    ----------
    clusters : list of Cluster
        The Clusters to process.
    template : callable
        Called with no arguments to produce the name of each new Scalar.

    Returns
    -------
    list
        One rebuilt Cluster per input Cluster, in the same order.
    """
    processed = []
    for c in clusters:
        # Get any Arrays appearing only in `c`, that is Arrays written in `c`
        # and not read by any of the other Clusters
        impacted = set(clusters) - {c}
        arrays = {i for i in c.scope.writes if i.is_Array}
        arrays -= set().union(*[i.scope.reads for i in impacted])

        # Turn them into scalars
        #
        # r[x,y,z] = g(b[x,y,z])                 t0 = g(b[x,y,z])
        # ... = r[x,y,z] + r[x,y,z+1]  ---->     t1 = g(b[x,y,z+1])
        #                                        ... = t0 + t1
        mapper = {}
        exprs = []
        for n, e in enumerate(c.exprs):
            f = e.lhs.function
            if f in arrays:
                # Accesses to `f` with a timestamp greater than `n`
                # NOTE(review): presumably the timestamp is the position of the
                # accessing expression within the Cluster -- confirm
                indexeds = [i.indexed for i in c.scope[f] if i.timestamp > n]
                for i in filter_ordered(indexeds):
                    # One fresh Scalar per distinct access
                    mapper[i] = Scalar(name=template(), dtype=f.dtype)

                    assert len(f.indices) == len(e.lhs.indices) == len(i.indices)
                    # Per-index shift: the offset between the access `i` and
                    # the LHS of `e`
                    shifting = {idx: idx + (o2 - o1) for idx, o1, o2 in
                                zip(f.indices, e.lhs.indices, i.indices)}

                    # Replicate `e` with the Scalar as LHS, then shift the
                    # indices of the resulting expression
                    handle = e.func(mapper[i], uxreplace(e.rhs, mapper))
                    handle = xreplace_indices(handle, shifting)
                    exprs.append(handle)
            else:
                # Expression kept, with the scalarized accesses replaced in its RHS
                exprs.append(e.func(e.lhs, uxreplace(e.rhs, mapper)))

        processed.append(c.rebuild(exprs))

    return processed


@timed_pass()
def eliminate_arrays(clusters, template):
"""
Expand Down
Loading

0 comments on commit 0ee4fac

Please sign in to comment.