Skip to content

Commit ce8613a

Browse files
committed
compiler: Generate less integer arithmetic
1 parent 0ccf5fd commit ce8613a

File tree

4 files changed

+34
-16
lines changed

4 files changed

+34
-16
lines changed

devito/ir/clusters/cluster.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ def dspace(self):
297297
continue
298298

299299
intervals = [Interval(d,
300-
min([minimum(i) for i in offs]),
301-
max([maximum(i) for i in offs]))
302-
for d, offs in v.items()]
300+
min([minimum(i, ispace=self.ispace) for i in o]),
301+
max([maximum(i, ispace=self.ispace) for i in o]))
302+
for d, o in v.items()]
303303
intervals = IntervalGroup(intervals)
304304

305305
# Factor in the IterationSpace -- if the min/max points aren't zero,

devito/ir/support/utils.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -308,30 +308,42 @@ def _relational(expr, callback, udims=None):
308308
return expr.subs(mapper)
309309

310310

311-
def minimum(expr, udims=None):
311+
def minimum(expr, udims=None, ispace=None):
312312
"""
313313
Substitute the unbounded Dimensions in `expr` with their minimum point.
314314
315315
Unbounded Dimensions whose possible minimum value is not known are ignored.
316316
"""
317-
return _relational(expr, lambda e: e._min, udims)
317+
def callback(sd):
318+
try:
319+
return sd._min + ispace[sd].lower
320+
except (TypeError, KeyError):
321+
return sd._min
322+
323+
return _relational(expr, callback, udims)
318324

319325

320-
def maximum(expr, udims=None):
326+
def maximum(expr, udims=None, ispace=None):
321327
"""
322328
Substitute the unbounded Dimensions in `expr` with their maximum point.
323329
324330
Unbounded Dimensions whose possible maximum value is not known are ignored.
325331
"""
326-
return _relational(expr, lambda e: e._max, udims)
332+
def callback(sd):
333+
try:
334+
return sd._max + ispace[sd].upper
335+
except (TypeError, KeyError):
336+
return sd._max
337+
338+
return _relational(expr, callback, udims)
327339

328340

329-
def extrema(expr):
341+
def extrema(expr, ispace=None):
330342
"""
331343
The minimum and maximum extrema assumed by `expr` once the unbounded
332344
Dimensions are resolved.
333345
"""
334-
return Extrema(minimum(expr), maximum(expr))
346+
return Extrema(minimum(expr, ispace=ispace), maximum(expr, ispace=ispace))
335347

336348

337349
def minmax_index(expr, d):

devito/passes/clusters/derivatives.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry):
115115
extra = (c.ispace.itdims + dims,)
116116
ispace = IterationSpace.union(c.ispace, ispace0, relations=extra)
117117

118+
# Set the IterationSpace along the StencilDimensions to start from 0
119+
# (rather than the default `d._min`) to minimize the amount of integer
120+
# arithmetic to calculate the various index access functions
121+
for d in dims:
122+
ispace = ispace.translate(d, -d._min)
123+
118124
try:
119125
s = reusables.pop()
120126
assert s.dtype is w.dtype
@@ -125,12 +131,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry):
125131
ispace1 = ispace.project(lambda d: d is not dims[-1])
126132
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1))
127133

128-
# Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the
129-
# StencilDimensions starting points
130-
subs = {expr.weights:
131-
expr.weights.subs(d, d - d._min)
132-
for d in dims}
133-
expr1 = Inc(s, uxreplace(expr.expr, subs))
134+
# Transform e.g. `r0[x + i0 + 2, y] -> r0[x + i0, y, z]` for alignment
135+
# with the shifted `ispace`
136+
base = expr.base
137+
for d in dims:
138+
base = base.subs(d, d + d._min)
139+
expr1 = Inc(s, base*expr.weights)
134140
processed.append(c.rebuild(exprs=expr1, ispace=ispace))
135141

136142
# Track lowered IndexDerivative for subsequent optimization by the caller

tests/test_unexpansion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_multiple_cross_derivs(self, coeffs, expected):
9393

9494
# w0, w1, ...
9595
functions = FindSymbols().visit(op)
96-
weights = [f for f in functions if isinstance(f, Weights)]
96+
weights = {f for f in functions if isinstance(f, Weights)}
9797
assert len(weights) == expected
9898

9999

0 commit comments

Comments
 (0)