Skip to content

Commit

Permalink
Merge pull request #1421 from devitocodes/fix-issue-1298
Browse files Browse the repository at this point in the history
Fix issue #1298
  • Loading branch information
mloubout authored Aug 5, 2020
2 parents dfabbf5 + 9859bed commit caac6b8
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 18 deletions.
21 changes: 11 additions & 10 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from devito.ir.clusters.analysis import analyze
from devito.ir.clusters.cluster import Cluster, ClusterGroup
from devito.ir.clusters.queue import QueueStateful
from devito.ir.equations.algorithms import lower_exprs
from devito.symbolics import CondEq
from devito.tools import timed_pass

__all__ = ['clusterize']
Expand Down Expand Up @@ -165,23 +163,26 @@ def guard(clusters):
processed = []
for c in clusters:
# Group together consecutive expressions with same ConditionalDimensions
for cds, g in groupby(c.exprs, key=lambda e: e.conditionals):
for cds, g in groupby(c.exprs, key=lambda e: tuple(e.conditionals)):
exprs = list(g)

if not cds:
processed.append(c.rebuild(exprs=exprs))
continue

# Create a guarded Cluster
# Chain together all conditions from all expressions in `c`
guards = {}
for cd in cds:
condition = guards.setdefault(cd.parent, [])
if cd.condition is None:
condition.append(CondEq(cd.parent % cd.factor, 0))
else:
condition.append(lower_exprs(cd.condition))
guards = {k: sympy.And(*v, evaluate=False) for k, v in guards.items()}
exprs = [e.func(*e.args, conditionals=dict(guards)) for e in exprs]
for e in exprs:
try:
condition.append(e.conditionals[cd])
break
except KeyError:
pass
guards = {d: sympy.And(*v, evaluate=False) for d, v in guards.items()}

# Construct a guarded Cluster
processed.append(c.rebuild(exprs=exprs, guards=guards))

return ClusterGroup(processed)
18 changes: 14 additions & 4 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from cached_property import cached_property
import sympy

from devito.ir.equations.algorithms import dimension_sort
from devito.ir.equations.algorithms import dimension_sort, lower_exprs
from devito.finite_differences.differentiable import diff2sympy
from devito.ir.support import (IterationSpace, DataSpace, Interval, IntervalGroup,
Stencil, detect_accesses, detect_oobs, detect_io,
build_intervals, build_iterators)
from devito.symbolics import CondEq
from devito.tools import Pickable, frozendict
from devito.types import Eq

Expand Down Expand Up @@ -122,7 +123,7 @@ def __new__(cls, *args, **kwargs):
# Analyze the expression
mapper = detect_accesses(expr)
oobs = detect_oobs(mapper)
conditionals = [i for i in ordering if i.is_Conditional]
conditional_dimensions = [i for i in ordering if i.is_Conditional]

# Construct Intervals for IterationSpace and DataSpace
intervals = build_intervals(Stencil.union(*mapper.values()))
Expand All @@ -144,11 +145,20 @@ def __new__(cls, *args, **kwargs):

# Construct the DataSpace
dintervals.extend([Interval(i, 0, 0) for i in ordering
if i not in ispace.dimensions + conditionals])
if i not in ispace.dimensions + conditional_dimensions])
parts = {k: IntervalGroup(build_intervals(v)).add(iintervals)
for k, v in mapper.items() if k}
dspace = DataSpace(dintervals, parts)

# Construct the conditionals
conditionals = {}
for d in conditional_dimensions:
if d.condition is None:
conditionals[d] = CondEq(d.parent % d.factor, 0)
else:
conditionals[d] = lower_exprs(d.condition)
conditionals = frozendict(conditionals)

# Lower all Differentiable operations into SymPy operations
rhs = diff2sympy(expr.rhs)

Expand All @@ -157,7 +167,7 @@ def __new__(cls, *args, **kwargs):

expr._dspace = dspace
expr._ispace = ispace
expr._conditionals = frozendict([(d, ()) for d in conditionals])
expr._conditionals = conditionals
expr._reads, expr._writes = detect_io(expr)

expr._is_Increment = input_expr.is_Increment
Expand Down
5 changes: 2 additions & 3 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,8 @@ def __init__(self, exprs, rules=None):
self.initialized.add(e.lhs.function)

# Look up ConditionalDimensions
for d, v in e.conditionals.items():
symbols = d.free_symbols | set(retrieve_terminals(v))
for j in symbols:
for v in e.conditionals.values():
for j in retrieve_terminals(v):
v = self.reads.setdefault(j.function, [])
v.append(TimedAccess(j, 'R', -1, e.ispace))

Expand Down
3 changes: 2 additions & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,15 @@ def detect_io(exprs, relax=False):
else:
rule = lambda i: i.is_Scalar or i.is_Tensor

# Don't forget this nasty case, with indirections on the LHS:
# Don't forget the nasty case with indirections on the LHS:
# >>> u[t, a[x]] = f[x] -> (reads={a, f}, writes={u})

roots = []
for i in exprs:
try:
roots.append(i.rhs)
roots.extend(list(i.lhs.indices))
roots.extend(list(i.conditionals.values()))
except AttributeError:
# E.g., FunctionFromPointer
roots.append(i)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,24 @@ def define(self, dimensions):
assert np.all(f.data[2:6, c1:c2] == 5.)
assert np.all(f.data[:, c3:c4] < 5.)

def test_from_cond_to_param(self):
"""
Test that Functions appearing in the condition of a ConditionalDimension
but not explicitly in an Eq are actually part of the Operator input
(stems from issue #1298).
"""
grid = Grid(shape=(8, 8))
x, y = grid.dimensions

g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
ci = ConditionalDimension(name='ci', parent=y, condition=Lt(g, 2 + h))
f = Function(name='f', shape=grid.shape, dimensions=(x, ci))

for _ in range(5):
# issue #1298 was non deterministic
Operator(Eq(f, 5)).apply()

@skipif('device')
def test_no_fusion_simple(self):
"""
Expand Down

0 comments on commit caac6b8

Please sign in to comment.