Skip to content

Commit

Permalink
tests: fix unevalution tests with new cross deriv
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 25, 2024
1 parent 2c51dfe commit ca2d378
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 22 deletions.
40 changes: 26 additions & 14 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,23 +222,32 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None)
side = side or self._side
x0 = self._process_x0(self.dims, x0=x0)
_x0 = frozendict({**self.x0, **x0})
if self.ndims == 1:
fd_order = fd_order or self._fd_order
method = method or self._method
weights = weights if weights is not None else self._weights
return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method,
weights=weights)

# Cross derivative

method = method or self._method
weights = weights if weights is not None else self._weights

# In case this was called on a cross derivative we need to propagate
# the call to the nested derivatibe
try:
new_expr = self.expr(x0=x0, fd_order=fd_order, side=side,
method=method, weights=weights)
except TypeError:
new_expr = self.expr

_fd_order = dict(self.fd_order.getters)
try:
_fd_order.update(fd_order or {})
_fd_order = tuple(_fd_order.values())
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)
except TypeError:
assert self.ndims == 1
_fd_order.update({self.dims[0]: fd_order or self.fd_order[0]})
except AttributeError:
raise TypeError("Multi-dimensional Derivative, input expected as a dict")
raise TypeError("fd_order incomaptible with dimensions")

return self._rebuild(fd_order=_fd_order, x0=_x0, side=side)
_fd_order = tuple(_fd_order.values())
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)

return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method,
weights=weights, expr=new_expr)

def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
Expand Down Expand Up @@ -291,7 +300,10 @@ def _xreplace(self, subs):
except AttributeError:
return new, True

new_expr = self.expr.xreplace(subs)
# Resolve nested derivatives
dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)}
new_expr = self.expr.xreplace(dsubs)

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._rebuild(subs=subs, expr=new_expr), True

Expand Down Expand Up @@ -445,7 +457,7 @@ def _eval_fd(self, expr, **kwargs):
# Step 3: Evaluate FD of the new expression
if self.method == 'RSFD':
assert len(self.dims) == 1
assert self.deriv_order == 1
assert self.deriv_order[0] == 1
res = d45(expr, self.dims[0], x0=self.x0, expand=expand)
elif len(self.dims) > 1:
assert self.method == 'FD'
Expand Down
17 changes: 9 additions & 8 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim):

@pytest.mark.parametrize('SymbolType, derivative, dim, expected', [
(Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'),
(Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'),
(Function, ['dx2dydz'], 3, 'Derivative(u(x, y, z), (x, 2), y, z)'),
(Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dydz'], 3,
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'),
(Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dy', 'dz2'], 3,
'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'),
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'),
(TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx2', 'dy'], 3,
'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3,
'Derivative(Derivative(Derivative(Derivative(Derivative(u(t, x, y, z), x), y),' +
' (x, 2)), z), y, z)')
'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' +
'u(t, x, y, z), x), y), (x, 2)), z), y), z)')
])
def test_unevaluation(self, SymbolType, derivative, dim, expected):
u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2)
Expand All @@ -111,13 +112,13 @@ def test_unevaluation(self, SymbolType, derivative, dim, expected):

@pytest.mark.parametrize('expr,expected', [
('u.dx + u.dy', 'Derivative(u, x) + Derivative(u, y)'),
('u.dxdy', 'Derivative(u, x, y)'),
('u.dxdy', 'Derivative(Derivative(u, x), y)'),
('u.laplace',
'Derivative(u, (x, 2)) + Derivative(u, (y, 2)) + Derivative(u, (z, 2))'),
('(u.dx + u.dy).dx', 'Derivative(Derivative(u, x) + Derivative(u, y), x)'),
('((u.dx + u.dy).dx + u.dxdy).dx',
'Derivative(Derivative(Derivative(u, x) + Derivative(u, y), x) +' +
' Derivative(u, x, y), x)'),
' Derivative(Derivative(u, x), y), x)'),
('(u**4).dx', 'Derivative(u**4, x)'),
('(u/4).dx', 'Derivative(u/4, x)'),
('((u.dx + v.dy).dx * v.dx).dy.dz',
Expand Down

0 comments on commit ca2d378

Please sign in to comment.