Skip to content

Commit

Permalink
compiler: Relax intervals with upper from not mapped dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 24, 2023
1 parent 2474a68 commit 2fa4880
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
22 changes: 10 additions & 12 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,22 +358,20 @@ def dspace(self):
# Dimension-centric view of the data space
intervals = IntervalGroup.generate('union', *parts.values())

# 'union' may have resulted in intervals stricter than needed
# e.g. issue #2235. We relax the upper interval with the upper
# from not mapped parts
for f, v in parts.items():
for d in f.dimensions:
# oobs check is not required but helps reduce
# interval reconstruction
if d in oobs and not v[d].is_Null:
intervals = intervals.set_upper(d, v[d].upper)

# E.g., `db0 -> time`, but `xi NOT-> x`
intervals = intervals.promote(lambda d: not d.is_Sub)
intervals = intervals.zero(set(intervals.dimensions) - oobs)

# Upper bound of intervals including dimensions classified for
# shifting should retain the "oobs" upper bound
for f, v in parts.items():
for i in v:
if i.dim in oobs:
try:
if intervals[i.dim].upper > v[i.dim].upper and \
bool(i.dim in f.dimensions):
intervals = intervals.ceil(v[i.dim])
except AttributeError:
pass

return DataSpace(intervals, parts)

@cached_property
Expand Down
12 changes: 5 additions & 7 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,8 @@ def negate(self):
def zero(self):
return Interval(self.dim, 0, 0, self.stamp)

def ceil(self, o):
if not self.is_compatible(o):
return self._rebuild()
return Interval(self.dim, self.lower, o.upper, self.stamp)
def set_upper(self, v=0):
return Interval(self.dim, self.lower, v, self.stamp)

def flip(self):
return Interval(self.dim, self.upper, self.lower, self.stamp)
Expand Down Expand Up @@ -497,9 +495,9 @@ def zero(self, d=None):

return IntervalGroup(intervals, relations=self.relations, mode=self.mode)

def ceil(self, o=None):
d = self.dimensions if o is None else as_tuple(o.dim)
return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self],
def set_upper(self, d, v=0):
dims = as_tuple(d)
return IntervalGroup([i.set_upper(v) if i.dim in dims else i for i in self],
relations=self.relations, mode=self.mode)

def lift(self, d=None, v=None):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,7 +1993,7 @@ class TestInternals(object):

@pytest.mark.parametrize('nt, offset, epass',
([1, 1, True], [1, 2, False],
[5, 1, True], [3, 5, False],
[5, 3, True], [3, 5, False],
[4, 1, True], [5, 10, False]))
def test_indirection(self, nt, offset, epass):
grid = Grid(shape=(4, 4))
Expand Down

0 comments on commit 2fa4880

Please sign in to comment.