diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index c8d5f3d0da..cb94dfc388 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -445,7 +445,7 @@ def _eval_fd(self, expr, **kwargs): # Step 3: Evaluate FD of the new expression if self.method == 'RSFD': assert len(self.dims) == 1 - assert self.deriv_order == 1 + assert self.deriv_order[0] == 1 res = d45(expr, self.dims[0], x0=self.x0, expand=expand) elif len(self.dims) > 1: assert self.method == 'FD' diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 77c59ea390..da68343c77 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -87,18 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim): @pytest.mark.parametrize('SymbolType, derivative, dim, expected', [ (Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'), - (Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'), - (Function, ['dx2dydz'], 3, 'Derivative(u(x, y, z), (x, 2), y, z)'), + (Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'), + (Function, ['dx2dydz'], 3, + 'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'), (Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'), (Function, ['dx2dy', 'dz2'], 3, - 'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'), + 'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'), (TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'), - (TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'), + (TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'), (TimeFunction, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'), (TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3, - 'Derivative(Derivative(Derivative(Derivative(Derivative(u(t, x, y, z), x), y),' + - ' (x, 2)), z), y, z)') + 'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' + + 'u(t, x, y, z), x), y), (x, 2)), z), y), z)') ]) def test_unevaluation(self, SymbolType, derivative, dim, expected): u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2)