diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 30a746b731..2e3699fe5e 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -297,9 +297,9 @@ def dspace(self): continue intervals = [Interval(d, - min([minimum(i) for i in offs]), - max([maximum(i) for i in offs])) - for d, offs in v.items()] + min([minimum(i, ispace=self.ispace) for i in o]), + max([maximum(i, ispace=self.ispace) for i in o])) + for d, o in v.items()] intervals = IntervalGroup(intervals) # Factor in the IterationSpace -- if the min/max points aren't zero, diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index acb851dbe7..ba6cdec84d 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -308,30 +308,42 @@ def _relational(expr, callback, udims=None): return expr.subs(mapper) -def minimum(expr, udims=None): +def minimum(expr, udims=None, ispace=None): """ Substitute the unbounded Dimensions in `expr` with their minimum point. Unbounded Dimensions whose possible minimum value is not known are ignored. """ - return _relational(expr, lambda e: e._min, udims) + def callback(sd): + try: + return sd._min + ispace[sd].lower + except (TypeError, KeyError): + return sd._min + + return _relational(expr, callback, udims) -def maximum(expr, udims=None): +def maximum(expr, udims=None, ispace=None): """ Substitute the unbounded Dimensions in `expr` with their maximum point. Unbounded Dimensions whose possible maximum value is not known are ignored. """ - return _relational(expr, lambda e: e._max, udims) + def callback(sd): + try: + return sd._max + ispace[sd].upper + except (TypeError, KeyError): + return sd._max + + return _relational(expr, callback, udims) -def extrema(expr): +def extrema(expr, ispace=None): """ The minimum and maximum extrema assumed by `expr` once the unbounded Dimensions are resolved. """ - return Extrema(minimum(expr), maximum(expr)) + return Extrema(minimum(expr, ispace=ispace), maximum(expr, ispace=ispace)) def minmax_index(expr, d): diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 77d114ad17..4992657aaf 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -115,6 +115,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry): extra = (c.ispace.itdims + dims,) ispace = IterationSpace.union(c.ispace, ispace0, relations=extra) + # Set the IterationSpace along the StencilDimensions to start from 0 + # (rather than the default `d._min`) to minimize the amount of integer + # arithmetic to calculate the various index access functions + for d in dims: + ispace = ispace.translate(d, -d._min) + try: s = reusables.pop() assert s.dtype is w.dtype @@ -125,12 +131,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry): ispace1 = ispace.project(lambda d: d is not dims[-1]) processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1)) - # Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the - # StencilDimensions starting points - subs = {expr.weights: - expr.weights.subs(d, d - d._min) - for d in dims} - expr1 = Inc(s, uxreplace(expr.expr, subs)) + # Transform e.g. `r0[x + i0 + 2, y] -> r0[x + i0, y, z]` for alignment + # with the shifted `ispace` + base = expr.base + for d in dims: + base = base.subs(d, d + d._min) + expr1 = Inc(s, base*expr.weights) processed.append(c.rebuild(exprs=expr1, ispace=ispace)) # Track lowered IndexDerivative for subsequent optimization by the caller diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 05ba84f0bd..186f806926 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -93,7 +93,7 @@ def test_multiple_cross_derivs(self, coeffs, expected): # w0, w1, ... functions = FindSymbols().visit(op) - weights = [f for f in functions if isinstance(f, Weights)] + weights = {f for f in functions if isinstance(f, Weights)} assert len(weights) == expected