Skip to content

Commit

Permalink
api: Support combination of condition and factor for ConditionalDimen…
Browse files Browse the repository at this point in the history
…sion
  • Loading branch information
mloubout committed Jul 17, 2024
1 parent 3054604 commit efe4344
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 4 deletions.
7 changes: 5 additions & 2 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,12 @@ def __new__(cls, *args, **kwargs):
if d.condition is None:
conditionals[d] = GuardFactor(d)
else:
conditionals[d] = diff2sympy(lower_exprs(d.condition))
cond = diff2sympy(lower_exprs(d.condition))
if d.factor is not None:
cond = sympy.And(cond, GuardFactor(d))
conditionals[d] = cond
if d.factor is not None:
expr = uxreplace(expr, {d: IntDiv(d.index, d.factor)})
expr = uxreplace(expr, {d: IntDiv(d.fact_index, d.factor)})
conditionals = frozendict(conditionals)

# Lower all Differentiable operations into SymPy operations
Expand Down
17 changes: 16 additions & 1 deletion devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from devito.types.args import ArgProvider
from devito.types.basic import Symbol, DataSymbol, Scalar
from devito.types.constant import Constant
from devito.types.relational import relational_min


__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
'CustomDimension', 'SteppingDimension', 'SubDimension',
Expand Down Expand Up @@ -872,6 +874,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
self._factor = factor
else:
raise ValueError("factor must be an integer or integer Constant")

self._condition = condition
self._indirect = indirect

Expand All @@ -892,6 +895,18 @@ def condition(self):
def indirect(self):
return self._indirect

@property
def fact_index(self):
if self.condition is None or self._factor is None:
return self.index

# This is the corner case where both a condition and a factor are provided
# the index will need to be `self.parent - min(self.condition)` to avoid
# shifted indexing. E.g if you have `factor=2` and `condition=Ge(time, 10)`
# then the lowered index needs to be `(time - 10)/ 2`
ltkn = relational_min(self.condition, self.parent)
return self.index - ltkn

@cached_property
def free_symbols(self):
retval = set(super().free_symbols)
Expand Down Expand Up @@ -927,7 +942,7 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
# `factor` endpoints are legal, so we return them all. It's then
# up to the caller to decide which one to pick upon reduction
dim = alias or self
if dim._factor is None or size is None:
if dim.condition is not None or size is None:
return defaults
try:
# Is it a symbolic factor?
Expand Down
43 changes: 42 additions & 1 deletion devito/types/relational.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""User API to specify relationals."""
from functools import singledispatch

import sympy

__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne']
__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne', 'relational_min']


class AbstractRel:
Expand Down Expand Up @@ -208,3 +209,43 @@ def __new__(cls, lhs, rhs=0, subdomain=None, **kwargs):

ops = {Ge: Lt, Gt: Le, Le: Gt, Lt: Ge}
rev = {Ge: Le, Gt: Lt, Lt: Gt, Le: Ge}


def relational_min(expr, s):
"""
Infer the minimum valid value for symbol `s` in the expression `expr`.
For example
- if `expr` is `s < 10`, then the minimum valid value for `s` is 0
- if `expr` is `s >= 10`, then the minimum valid value for `s` is 10
"""
assert expr.has(s), "Symbol %s not found in expression %s" % (s, expr)

return _relational_min(expr, s)


@singledispatch
def _relational_min(s, expr):
return 0


@_relational_min.register(sympy.And)
def _(expr, s):
return max([_relational_min(e, s) for e in expr.args])


@_relational_min.register(Gt)
@_relational_min.register(Lt)
def _(expr, s):
if s == expr.gts:
return expr.lts + 1
else:
return 0


@_relational_min.register(Ge)
@_relational_min.register(Le)
def _(expr, s):
if s == expr.gts:
return expr.lts
else:
return 0
29 changes: 29 additions & 0 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,35 @@ def test_issue_1753(self):
op.apply(time_M=1)
assert np.all(np.flatnonzero(f.data) == [3, 30])

def test_issue_2273(self):
grid = Grid(shape=(11, 11))
time = grid.time_dim

nt = 200
bounds = (10, 100)
factor = 5

condition = And(Ge(time, bounds[0]), Le(time, bounds[1]))

time_under = ConditionalDimension(name='timeu', parent=time,
factor=factor, condition=condition)
buffer_size = (bounds[1] - bounds[0] + factor) // factor + 1

rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt,
coordinates=[(.5, .5)])
rec.data[:] = 1.0

u = TimeFunction(name='u', grid=grid, space_order=2)
usaved = TimeFunction(name='usaved', grid=grid, space_order=2,
time_dim=time_under, save=buffer_size)

eq = [Eq(u.forward, u)] + rec.inject(field=u.forward, expr=rec) + [Eq(usaved, u)]

op = Operator(eq)
op(time_m=0, time_M=nt-1)
expected = np.linspace(bounds[0], bounds[1], num=buffer_size-1)
assert np.allclose(usaved.data[:-1, 5, 5], expected)

def test_subsampled_fd(self):
"""
Test that the FD shortcuts are handled correctly with ConditionalDimensions
Expand Down

0 comments on commit efe4344

Please sign in to comment.