Skip to content

Commit

Permalink
api: cleanup fd transpose implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 13, 2023
1 parent b4264b1 commit c1ebe2f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 29 deletions.
14 changes: 3 additions & 11 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,11 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic
weights = numeric_weights(deriv_order, indices, x0)

# Enforce fixed precision FD coefficients to avoid variations in results
weights = [sympify(w).evalf(_PRECISION) for w in weights]
weights = [sympify(w).evalf(_PRECISION) for w in weights][::matvec.val]

# Transpose the FD, if necessary
indices = indices.scale(matvec.val)
if matvec == transpose:
indices = indices.transpose()

# Shift index due to staggering, if any
indices = indices.shift(-(expr.indices_ref[dim] - dim))
Expand All @@ -235,15 +236,6 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic
if not expand and indices.expr is not None:
weights = Weights(name='w', dimensions=indices.free_dim, initvalue=weights)

if matvec == transpose:
# For homogeneity, always generate e.g. `x + i0` rather than `x - i0`
# for transpose and `x + i0` for direct
indices = indices.transpose()

# Do the same for the Weights, though this is more than just a
# transposition, we also must switch to the transposed StencilDimension
weights = weights._subs(weights.dimension, -indices.free_dim)

# Inject the StencilDimension
# E.g. `x + i*h_x` into `f(x)` s.t. `f(x + i*h_x)`
expr = expr._subs(dim, indices.expr)
Expand Down
25 changes: 8 additions & 17 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,14 @@ def __repr__(self):
def spacing(self):
return self.dim.spacing

def scale(self, v):
def transpose(self):
"""
Construct a new IndexSet with all indices scaled by `v`.
Transpose the IndexSet.
"""
mapper = {self.spacing: v*self.spacing}
mapper = {self.spacing: -self.spacing}

indices = []
for i in self:
for i in reversed(self):
try:
iloc = i.xreplace(mapper)
except AttributeError:
Expand All @@ -191,22 +191,13 @@ def scale(self, v):
indices.append(iloc)

try:
expr = self.expr.xreplace(mapper)
free_dim = self.free_dim.transpose()
mapper.update({self.free_dim: -free_dim})
except AttributeError:
expr = None

return IndexSet(self.dim, indices, expr=expr, fd=self.free_dim)

def transpose(self):
"""
Transpose the IndexSet.
"""
indices = tuple(reversed(self))

free_dim = self.free_dim.transpose()
free_dim = self.free_dim

try:
expr = self.expr._subs(self.free_dim, -free_dim)
expr = self.expr.xreplace(mapper)
except AttributeError:
expr = None

Expand Down
2 changes: 1 addition & 1 deletion devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _core(expr, c, weights, mapper, sregistry):
# Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the
# StencilDimensions starting points
subs = {expr.weights:
expr.weights.subs(d, d - (d._max if d.backward else d._min))
expr.weights.subs(d, d - d._min)
for d in dims}
expr1 = Inc(s, uxreplace(expr.expr, subs))
processed.append(c.rebuild(exprs=expr1, ispace=ispace))
Expand Down

0 comments on commit c1ebe2f

Please sign in to comment.