Skip to content

Commit

Permalink
api: fix corner case for staggered FD
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 25, 2024
1 parent cf59cbb commit 8157110
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
6 changes: 5 additions & 1 deletion devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _process_x0(cls, dims, **kwargs):
# Only given a value
_x0 = kwargs.get('x0')
assert len(dims) == 1 or _x0 is None
if _x0 is not None:
if _x0 is not None and _x0 is not dims[0]:
x0 = frozendict({dims[0]: _x0})
else:
x0 = frozendict({})
Expand Down Expand Up @@ -360,6 +360,10 @@ def _eval_at(self, func):
# do not overwrite it
if self.x0 or self.side is not None or func.function is self.expr.function:
return self
# For basic equation of the form f = Derivative(g, ...) we can just
# compare staggering
if self.expr.staggered == func.staggered:
return self

x0 = func.indices_ref._getters
if self.expr.is_Add:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,16 @@ def test_deriv_spec(self):
assert dxy0.x0 == {y: y+y.spacing/2}
assert dxy02.x0 == {x: x+x.spacing/2}

def test_deriv_stagg_plain(self):
grid = Grid((11, 11))
x, y = grid.dimensions
f1 = Function(name="f1", grid=grid, space_order=2, staggered=NODE)
f2 = Function(name="f2", grid=grid, space_order=2, staggered=NODE)

eq0 = Eq(f1, f2.laplace).evaluate
assert eq0.rhs == f2.laplace.evaluate
assert eq0.rhs != 0


class TestTwoStageEvaluation:

Expand Down

0 comments on commit 8157110

Please sign in to comment.