Skip to content

Commit

Permalink
compiler: Only relax upper dspace in case of save
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Oct 18, 2023
1 parent 9ff806c commit e96bb5e
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 39 deletions.
20 changes: 6 additions & 14 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,24 +327,16 @@ def dspace(self):
# Construct the `intervals` of the DataSpace, that is a global,
# Dimension-centric view of the data space
intervals = IntervalGroup.generate('union', *parts.values())

# E.g., `db0 -> time`, but `xi NOT-> x`
intervals = intervals.promote(lambda d: not d.is_Sub)
intervals = intervals.zero(set(intervals.dimensions) - oobs)

# Intersect with intervals from buffered dimensions. Unions of
# buffered dimension intervals may result in shrinking time size
try:
proc = []
for f, v in parts.items():
if f.save:
for i in v:
if i.dim.is_Time:
proc.append(intervals[i.dim].intersection(i))
else:
proc.append(intervals[i.dim])
intervals = IntervalGroup(proc)
except AttributeError:
pass
# Buffered TimeDimensions should not shirnk their upper time offset
for f, v in parts.items():
if f.is_TimeFunction:
if f.save and not f.time_dim.is_Conditional:
intervals = intervals.ceil(v[f.time_dim])

return DataSpace(intervals, parts)

Expand Down
10 changes: 10 additions & 0 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ def negate(self):
def zero(self):
return Interval(self.dim, 0, 0, self.stamp)

def ceil(self, o):
if o.is_Null:
return self._rebuild()
return Interval(self.dim, self.lower, o.upper, self.stamp)

def flip(self):
return Interval(self.dim, self.upper, self.lower, self.stamp)

Expand Down Expand Up @@ -492,6 +497,11 @@ def zero(self, d=None):

return IntervalGroup(intervals, relations=self.relations, mode=self.mode)

def ceil(self, o=None):
d = self.dimensions if o is None else as_tuple(o.dim)
return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self],
relations=self.relations)

def lift(self, d=None, v=None):
d = set(self.dimensions if d is None else as_tuple(d))
intervals = [i.lift(v) if i.dim._defines & d else i for i in self]
Expand Down
21 changes: 0 additions & 21 deletions tests/test_buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,24 +752,3 @@ def test_stencil_issue_1915_v2(subdomain):
op1.apply(time_M=nt-2, u=u1)

assert np.all(u.data == u1.data)


def test_default_timeM():
"""
MFE for issue #2235
"""
grid = Grid(shape=(4, 4))

u = TimeFunction(name='u', grid=grid)
usave = TimeFunction(name='usave', grid=grid, save=5)

eqns = [Eq(u.forward, u + 1),
Eq(usave, u)]

op = Operator(eqns)

assert op.arguments()['time_M'] == 4

op.apply()

assert all(np.all(usave.data[i] == i) for i in range(4))
2 changes: 1 addition & 1 deletion tests/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@switchconfig(log_level='WARNING')
def test_segmented_incremment():
def test_segmented_increment():
"""
Test for segmented operator execution of a one-sided first order
function (increment). The corresponding set of stencil offsets in
Expand Down
21 changes: 20 additions & 1 deletion tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,25 @@ def test_modulo_dims_generation_v2(self):
assert np.all(f.data[3] == 2)
assert np.all(f.data[4] == 4)

def test_default_timeM(self):
"""
MFE for issue #2235
"""
grid = Grid(shape=(4, 4))

u = TimeFunction(name='u', grid=grid)
usave = TimeFunction(name='usave', grid=grid, save=5)

eqns = [Eq(u.forward, u + 1),
Eq(usave, u)]

op = Operator(eqns)

assert op.arguments()['time_M'] == 4
op.apply()

assert all(np.all(usave.data[i] == i) for i in range(4))


class TestSubDimension(object):

Expand Down Expand Up @@ -760,7 +779,7 @@ def test_basic(self):

eqns = [Eq(u.forward, u + 1.), Eq(u2.forward, u2 + 1.), Eq(usave, u)]
op = Operator(eqns)
op.apply()
op.apply(time_M=nt-2)
assert np.all(np.allclose(u.data[(nt-1) % 3], nt-1))
assert np.all([np.allclose(u2.data[i], i) for i in range(nt)])
assert np.all([np.allclose(usave.data[i], i*factor)
Expand Down
31 changes: 29 additions & 2 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,11 +2010,38 @@ def test_indirection(self):

op = Operator(eqns)

assert op._dspace[time].lower == 1
assert op._dspace[time].lower == 0
assert op._dspace[time].upper == 1
assert op.arguments()['time_M'] == nt - 2

op()
op.apply()

assert np.all(f.data[0] == 0.)
assert np.all(f.data[i] == 3. for i in range(1, 10))

def test_indirection_v2(self):
nt = 10
grid = Grid(shape=(4, 4))
time = grid.time_dim
x, y = grid.dimensions

f = TimeFunction(name='f', grid=grid, save=nt)
g = TimeFunction(name='g', grid=grid)

idx = time
s = Indirection(name='ofs0', mapped=idx)

eqns = [
Eq(s, idx),
Eq(f[s, x, y], g + 3.)
]

op = Operator(eqns)

assert op._dspace[time].lower == 0
assert op._dspace[time].upper == 0
assert op.arguments()['time_M'] == nt - 1

op.apply()

assert np.all(f.data[i] == 3. for i in range(1, 10))

0 comments on commit e96bb5e

Please sign in to comment.