From 2fa4880208e720059dc38d529a18dfdb4a1a4725 Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Fri, 24 Nov 2023 14:41:01 +0000 Subject: [PATCH] compiler: Relax intervals with upper from not mapped dimensions --- devito/ir/clusters/cluster.py | 22 ++++++++++------------ devito/ir/support/space.py | 12 +++++------- tests/test_operator.py | 2 +- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 7a0ab4e42e7..a35513708ae 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -358,22 +358,20 @@ def dspace(self): # Dimension-centric view of the data space intervals = IntervalGroup.generate('union', *parts.values()) + # 'union' may have resulted in intervals stricter than needed + # e.g. issue #2235. We relax the upper interval with the upper + # from not mapped parts + for f, v in parts.items(): + for d in f.dimensions: + # oobs check is not required but helps reduce + # interval reconstruction + if d in oobs and not v[d].is_Null: + intervals = intervals.set_upper(d, v[d].upper) + # E.g., `db0 -> time`, but `xi NOT-> x` intervals = intervals.promote(lambda d: not d.is_Sub) intervals = intervals.zero(set(intervals.dimensions) - oobs) - # Upper bound of intervals including dimensions classified for - # shifting should retain the "oobs" upper bound - for f, v in parts.items(): - for i in v: - if i.dim in oobs: - try: - if intervals[i.dim].upper > v[i.dim].upper and \ - bool(i.dim in f.dimensions): - intervals = intervals.ceil(v[i.dim]) - except AttributeError: - pass - return DataSpace(intervals, parts) @cached_property diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 310657b1afc..c37c2a26c8b 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -259,10 +259,8 @@ def negate(self): def zero(self): return Interval(self.dim, 0, 0, self.stamp) - def ceil(self, o): - if not self.is_compatible(o): - return self._rebuild() - return Interval(self.dim, self.lower, o.upper, self.stamp) + def set_upper(self, v=0): + return Interval(self.dim, self.lower, v, self.stamp) def flip(self): return Interval(self.dim, self.upper, self.lower, self.stamp) @@ -497,9 +495,9 @@ def zero(self, d=None): return IntervalGroup(intervals, relations=self.relations, mode=self.mode) - def ceil(self, o=None): - d = self.dimensions if o is None else as_tuple(o.dim) - return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self], + def set_upper(self, d, v=0): + dims = as_tuple(d) + return IntervalGroup([i.set_upper(v) if i.dim in dims else i for i in self], relations=self.relations, mode=self.mode) def lift(self, d=None, v=None): diff --git a/tests/test_operator.py b/tests/test_operator.py index 47e4d85bada..7972ec4d50f 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1993,7 +1993,7 @@ class TestInternals(object): @pytest.mark.parametrize('nt, offset, epass', ([1, 1, True], [1, 2, False], - [5, 1, True], [3, 5, False], + [5, 3, True], [3, 5, False], [4, 1, True], [5, 10, False])) def test_indirection(self, nt, offset, epass): grid = Grid(shape=(4, 4))