Skip to content

Commit

Permalink
Merge pull request #2410 from devitocodes/cse-conds
Browse files Browse the repository at this point in the history
compiler: fix cse with different conditionals
  • Loading branch information
mloubout authored Jul 16, 2024
2 parents 6e79f1c + a9a4d14 commit 3054604
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 11 deletions.
2 changes: 1 addition & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def detect_accesses(exprs):
other_dims = set()
for e in as_tuple(exprs):
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
other_dims.update(e.implicit_dims)
other_dims.update(e.implicit_dims or {})
other_dims = filter_sorted(other_dims)
mapper[None] = Stencil([(i, 0) for i in other_dims])

Expand Down
32 changes: 24 additions & 8 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import Counter, OrderedDict
from collections import Counter, OrderedDict, namedtuple
from functools import singledispatch

import sympy
from sympy import Add, Function, Indexed, Mul, Pow
try:
from sympy.core.core import ordering_of_classes
Expand All @@ -13,12 +14,15 @@
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.tools import as_list, frozendict
from devito.types import Eq, Symbol, Temp

__all__ = ['cse']


Counted = namedtuple('Candidate', 'expr, conditionals')


class CTemp(Temp):

"""
Expand Down Expand Up @@ -91,12 +95,11 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):
while True:
# Detect redundancies
counted = count(processed).items()
targets = OrderedDict([(k, estimate_cost(k, True)) for k, v in counted if v > 1])

targets = OrderedDict([(k, estimate_cost(k.expr, True))
for k, v in counted if v > 1])
# Rule out Dimension-independent data dependencies
targets = OrderedDict([(k, v) for k, v in targets.items()
if not k.free_symbols & exclude])

if not k.expr.free_symbols & exclude])
if not targets or max(targets.values()) < min_cost:
break

Expand All @@ -112,9 +115,11 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):
for e in processed:
pe = e
for k, v in chosen:
pe, changed = _uxreplace(pe, {k: v})
if not k.conditionals == e.conditionals:
continue
pe, changed = _uxreplace(pe, {k.expr: v})
if changed and v not in scheduled:
updated.append(pe.func(v, k, operation=None))
updated.append(pe.func(v, k.expr, operation=None))
scheduled.append(v)
updated.append(pe)
processed = updated
Expand Down Expand Up @@ -172,9 +177,20 @@ def _(exprs):
mapper = Counter()
for e in exprs:
mapper.update(count(e))

return mapper


@count.register(sympy.Eq)
def _(expr):
mapper = count(expr.rhs)
try:
cond = expr.conditionals
except AttributeError:
cond = frozendict()
return {Counted(e, cond): v for e, v in mapper.items()}


@count.register(Indexed)
@count.register(Symbol)
def _(expr):
Expand Down
6 changes: 5 additions & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import cached_property

from devito.finite_differences import default_rules
from devito.tools import as_tuple
from devito.tools import as_tuple, frozendict
from devito.types.lazy import Evaluable

__all__ = ['Eq', 'Inc', 'ReduceMax', 'ReduceMin']
Expand Down Expand Up @@ -137,6 +137,10 @@ def substitutions(self):
def implicit_dims(self):
return self._implicit_dims

@property
def conditionals(self):
return frozendict()

@cached_property
def _uses_symbolic_coefficients(self):
return bool(self._symbolic_functions)
Expand Down
51 changes: 50 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt, Ge)
sin, sqrt, Ge, Lt)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.finite_differences.differentiable import diffify
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
Expand Down Expand Up @@ -191,6 +191,55 @@ def test_cse_w_conditionals():
assert len(FindNodes(Conditional).visit(op)) == 1


def test_cse_w_multi_conditionals():
grid = Grid(shape=(10, 10, 10))
x, _, _ = grid.dimensions

cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4),
indirect=True)

cd2 = ConditionalDimension(name='cd2', parent=x, condition=Lt(x, 4),
indirect=True)

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
a0 = Function(name='a0', grid=grid)
a1 = Function(name='a1', grid=grid)
a2 = Function(name='a2', grid=grid)
a3 = Function(name='a3', grid=grid)

eq0 = Eq(h, a0, implicit_dims=cd)
eq1 = Eq(a0, a0 + f*g, implicit_dims=cd)
eq2 = Eq(a1, a1 + f*g, implicit_dims=cd)
eq3 = Eq(a2, a2 + f*g, implicit_dims=cd2)
eq4 = Eq(a3, a3 + f*g, implicit_dims=cd2)

op = Operator([eq0, eq1, eq3])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 0

op = Operator([eq0, eq1, eq3, eq4])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 1

op = Operator([eq0, eq1, eq2, eq3, eq4])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 2


@pytest.mark.parametrize('expr,expected', [
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
('fa[x]**2', 'fa[x]*fa[x]'),
Expand Down

0 comments on commit 3054604

Please sign in to comment.