diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index c8d5f3d0da..45610675ac 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -222,23 +222,32 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None) side = side or self._side x0 = self._process_x0(self.dims, x0=x0) _x0 = frozendict({**self.x0, **x0}) - if self.ndims == 1: - fd_order = fd_order or self._fd_order - method = method or self._method - weights = weights if weights is not None else self._weights - return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method, - weights=weights) - - # Cross derivative + + method = method or self._method + weights = weights if weights is not None else self._weights + + # In case this was called on a cross derivative we need to propagate + # the call to the nested derivatibe + try: + new_expr = self.expr(x0=x0, fd_order=fd_order, side=side, + method=method, weights=weights) + except TypeError: + new_expr = self.expr + _fd_order = dict(self.fd_order.getters) try: _fd_order.update(fd_order or {}) - _fd_order = tuple(_fd_order.values()) - _fd_order = DimensionTuple(*_fd_order, getters=self.dims) + except TypeError: + assert self.ndims == 1 + _fd_order.update({self.dims[0]: fd_order or self.fd_order[0]}) except AttributeError: - raise TypeError("Multi-dimensional Derivative, input expected as a dict") + raise TypeError("fd_order incomaptible with dimensions") - return self._rebuild(fd_order=_fd_order, x0=_x0, side=side) + _fd_order = tuple(_fd_order.values()) + _fd_order = DimensionTuple(*_fd_order, getters=self.dims) + + return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method, + weights=weights, expr=new_expr) def _rebuild(self, *args, **kwargs): kwargs['preprocessed'] = True @@ -291,7 +300,10 @@ def _xreplace(self, subs): except AttributeError: return new, True - new_expr = self.expr.xreplace(subs) + # Resolve nested derivatives + dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)} + new_expr = self.expr.xreplace(dsubs) + subs = self._ppsubs + (subs,) # Postponed substitutions return self._rebuild(subs=subs, expr=new_expr), True @@ -445,7 +457,7 @@ def _eval_fd(self, expr, **kwargs): # Step 3: Evaluate FD of the new expression if self.method == 'RSFD': assert len(self.dims) == 1 - assert self.deriv_order == 1 + assert self.deriv_order[0] == 1 res = d45(expr, self.dims[0], x0=self.x0, expand=expand) elif len(self.dims) > 1: assert self.method == 'FD' diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 77c59ea390..2532c296d1 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -87,18 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim): @pytest.mark.parametrize('SymbolType, derivative, dim, expected', [ (Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'), - (Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'), - (Function, ['dx2dydz'], 3, 'Derivative(u(x, y, z), (x, 2), y, z)'), + (Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'), + (Function, ['dx2dydz'], 3, + 'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'), (Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'), (Function, ['dx2dy', 'dz2'], 3, - 'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'), + 'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'), (TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'), - (TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'), + (TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'), (TimeFunction, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'), (TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3, - 'Derivative(Derivative(Derivative(Derivative(Derivative(u(t, x, y, z), x), y),' + - ' (x, 2)), z), y, z)') + 'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' + + 'u(t, x, y, z), x), y), (x, 2)), z), y), z)') ]) def test_unevaluation(self, SymbolType, derivative, dim, expected): u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2) @@ -111,13 +112,13 @@ def test_unevaluation(self, SymbolType, derivative, dim, expected): @pytest.mark.parametrize('expr,expected', [ ('u.dx + u.dy', 'Derivative(u, x) + Derivative(u, y)'), - ('u.dxdy', 'Derivative(u, x, y)'), + ('u.dxdy', 'Derivative(Derivative(u, x), y)'), ('u.laplace', 'Derivative(u, (x, 2)) + Derivative(u, (y, 2)) + Derivative(u, (z, 2))'), ('(u.dx + u.dy).dx', 'Derivative(Derivative(u, x) + Derivative(u, y), x)'), ('((u.dx + u.dy).dx + u.dxdy).dx', 'Derivative(Derivative(Derivative(u, x) + Derivative(u, y), x) +' + - ' Derivative(u, x, y), x)'), + ' Derivative(Derivative(u, x), y), x)'), ('(u**4).dx', 'Derivative(u**4, x)'), ('(u/4).dx', 'Derivative(u/4, x)'), ('((u.dx + v.dy).dx * v.dx).dy.dz',