From 9126ec38d025abfc50f2cb8964242abef630dbb0 Mon Sep 17 00:00:00 2001 From: Mathias Louboutin Date: Wed, 11 Oct 2023 13:59:13 -0400 Subject: [PATCH] api: cleanup fd transpose implementation --- .../finite_differences/finite_difference.py | 14 +++-------- devito/finite_differences/tools.py | 25 ++++++------------- devito/passes/clusters/derivatives.py | 2 +- 3 files changed, 12 insertions(+), 29 deletions(-) diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index 44859b09038..5749a82ae42 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -220,10 +220,11 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic weights = numeric_weights(deriv_order, indices, x0) # Enforce fixed precision FD coefficients to avoid variations in results - weights = [sympify(w).evalf(_PRECISION) for w in weights] + weights = [sympify(w).evalf(_PRECISION) for w in weights][::matvec.val] # Transpose the FD, if necessary - indices = indices.scale(matvec.val) + if matvec == transpose: + indices = indices.transpose() # Shift index due to staggering, if any indices = indices.shift(-(expr.indices_ref[dim] - dim)) @@ -235,15 +236,6 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic if not expand and indices.expr is not None: weights = Weights(name='w', dimensions=indices.free_dim, initvalue=weights) - if matvec == transpose: - # For homogeneity, always generate e.g. `x + i0` rather than `x - i0` - # for transpose and `x + i0` for direct - indices = indices.transpose() - - # Do the same for the Weights, though this is more than just a - # transposition, we also must switch to the transposed StencilDimension - weights = weights._subs(weights.dimension, -indices.free_dim) - # Inject the StencilDimension # E.g. `x + i*h_x` into `f(x)` s.t. `f(x + i*h_x)` expr = expr._subs(dim, indices.expr) diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index e1f209c4d6f..7ab19e76640 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -175,14 +175,14 @@ def __repr__(self): def spacing(self): return self.dim.spacing - def scale(self, v): + def transpose(self): """ - Construct a new IndexSet with all indices scaled by `v`. + Transpose the IndexSet. """ - mapper = {self.spacing: v*self.spacing} + mapper = {self.spacing: -self.spacing} indices = [] - for i in self: + for i in reversed(self): try: iloc = i.xreplace(mapper) except AttributeError: @@ -191,22 +191,13 @@ def scale(self, v): indices.append(iloc) try: - expr = self.expr.xreplace(mapper) + free_dim = self.free_dim.transpose() + mapper.update({self.free_dim: -free_dim}) except AttributeError: - expr = None - - return IndexSet(self.dim, indices, expr=expr, fd=self.free_dim) - - def transpose(self): - """ - Transpose the IndexSet. - """ - indices = tuple(reversed(self)) - - free_dim = self.free_dim.transpose() + free_dim = self.free_dim try: - expr = self.expr._subs(self.free_dim, -free_dim) + expr = self.expr.xreplace(mapper) except AttributeError: expr = None diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 58990ca9e90..01e5afea6a2 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -102,7 +102,7 @@ def _core(expr, c, weights, mapper, sregistry): # Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the # StencilDimensions starting points subs = {expr.weights: - expr.weights.subs(d, d - (d._max if d.backward else d._min)) + expr.weights.subs(d, d - d._min) for d in dims} expr1 = Inc(s, uxreplace(expr.expr, subs)) processed.append(c.rebuild(exprs=expr1, ispace=ispace))