Skip to content

Commit

Permalink
compiler: Patch CSE in presence of conditionals
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini authored and mloubout committed Jul 1, 2024
1 parent 7f77489 commit 90ad3e4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
14 changes: 8 additions & 6 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,26 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):

# Create temporaries
hit = max(targets.values())
temps = [Eq(make(), k) for k, v in targets.items() if v == hit]
chosen = [(k, make()) for k, v in targets.items() if v == hit]

# Apply replacements
# The extracted temporaries are inserted before the first expression
# that contains it
scheduled = []
updated = []
for e in processed:
pe = e
for t in temps:
pe, changed = _uxreplace(pe, {t.rhs: t.lhs})
if changed and t not in updated:
updated.append(t)
for k, v in chosen:
pe, changed = _uxreplace(pe, {k: v})
if changed and v not in scheduled:
updated.append(pe.func(v, k, operation=None))
scheduled.append(v)
updated.append(pe)
processed = updated

# Update `exclude` for the same reasons as above -- to rule out CSE across
# Dimension-independent data dependences
exclude.update({t.lhs for t in temps})
exclude.update(scheduled)

# At this point we may have useless temporaries (e.g., r0=r1). Let's drop them
processed = _compact_temporaries(processed, exclude)
Expand Down
25 changes: 24 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt)
sin, sqrt, Ge)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.finite_differences.differentiable import diffify
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
Expand Down Expand Up @@ -168,6 +168,29 @@ def test_cse_temp_order():
assert type(args[2]) is CTemp


def test_cse_w_conditionals():
grid = Grid(shape=(10, 10, 10))
x, _, _ = grid.dimensions

cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4),
indirect=True)

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
a0 = Function(name='a0', grid=grid)
a1 = Function(name='a1', grid=grid)

eqns = [Eq(h, a0, implicit_dims=cd),
Eq(a0, a0 + f*g, implicit_dims=cd),
Eq(a1, a1 + f*g, implicit_dims=cd)]

op = Operator(eqns)

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 1


@pytest.mark.parametrize('expr,expected', [
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
('fa[x]**2', 'fa[x]*fa[x]'),
Expand Down

0 comments on commit 90ad3e4

Please sign in to comment.