diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 74754bad96f..eae9dbd9ff9 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -379,22 +379,22 @@ def dspace(self): # Dimension-centric view of the data space intervals = IntervalGroup.generate('union', *parts.values()) + # 'union' may consume intervals (values) from keys that have dimensions + # not mapped to intervals e.g. issue #2235, resulting in reduced + # iteration size. Here, we relax this mapped upper interval, by + # intersecting intervals with matching only dimensions + for f, v in parts.items(): + for i in v: + # oobs check is not required but helps reduce + # interval reconstruction + if i.dim in oobs and i.dim in f.dimensions: + ii = intervals[i.dim].intersection(v[i.dim]) + intervals = intervals.set_upper(i.dim, ii.upper) + # E.g., `db0 -> time`, but `xi NOT-> x` intervals = intervals.promote(lambda d: not d.is_Sub) intervals = intervals.zero(set(intervals.dimensions) - oobs) - # Upper bound of intervals including dimensions classified for - # shifting should retain the "oobs" upper bound - for f, v in parts.items(): - for i in v: - if i.dim in oobs: - try: - if intervals[i.dim].upper > v[i.dim].upper and \ - bool(i.dim in f.dimensions): - intervals = intervals.ceil(v[i.dim]) - except AttributeError: - pass - return DataSpace(intervals, parts) @cached_property diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index c7be57b8d83..573f5e7ba9e 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -258,10 +258,8 @@ def negate(self): def zero(self): return Interval(self.dim, 0, 0, self.stamp) - def ceil(self, o): - if not self.is_compatible(o): - return self._rebuild() - return Interval(self.dim, self.lower, o.upper, self.stamp) + def set_upper(self, v=0): + return Interval(self.dim, self.lower, v, self.stamp) def flip(self): return Interval(self.dim, self.upper, self.lower, self.stamp) @@ -496,9 +494,9 @@ def zero(self, d=None): return IntervalGroup(intervals, relations=self.relations, mode=self.mode) - def ceil(self, o=None): - d = self.dimensions if o is None else as_tuple(o.dim) - return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self], + def set_upper(self, d, v=0): + dims = as_tuple(d) + return IntervalGroup([i.set_upper(v) if i.dim in dims else i for i in self], relations=self.relations, mode=self.mode) def lift(self, d=None, v=None): diff --git a/tests/test_operator.py b/tests/test_operator.py index fb0aaafae80..9d245664682 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1992,7 +1992,7 @@ class TestInternals: @pytest.mark.parametrize('nt, offset, epass', ([1, 1, True], [1, 2, False], - [5, 1, True], [3, 5, False], + [5, 3, True], [3, 5, False], [4, 1, True], [5, 10, False])) def test_indirection(self, nt, offset, epass): grid = Grid(shape=(4, 4))