Skip to content

Commit

Permalink
tests: fix unevalution tests with new cross deriv
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 24, 2024
1 parent 2c51dfe commit 3479ba0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _eval_fd(self, expr, **kwargs):
# Step 3: Evaluate FD of the new expression
if self.method == 'RSFD':
assert len(self.dims) == 1
assert self.deriv_order == 1
assert self.deriv_order[0] == 1
res = d45(expr, self.dims[0], x0=self.x0, expand=expand)
elif len(self.dims) > 1:
assert self.method == 'FD'
Expand Down
13 changes: 7 additions & 6 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim):

@pytest.mark.parametrize('SymbolType, derivative, dim, expected', [
(Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'),
(Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'),
(Function, ['dx2dydz'], 3, 'Derivative(u(x, y, z), (x, 2), y, z)'),
(Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dydz'], 3,
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'),
(Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dy', 'dz2'], 3,
'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'),
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'),
(TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx2', 'dy'], 3,
'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3,
'Derivative(Derivative(Derivative(Derivative(Derivative(u(t, x, y, z), x), y),' +
' (x, 2)), z), y, z)')
'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' +
'u(t, x, y, z), x), y), (x, 2)), z), y), z)')
])
def test_unevaluation(self, SymbolType, derivative, dim, expected):
u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2)
Expand Down

0 comments on commit 3479ba0

Please sign in to comment.