From 602c448a2772a072af6226c2471439d9e84628ff Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 11 Sep 2024 13:33:13 +0000 Subject: [PATCH] tests: Update after dropping redundant evaluation --- tests/test_differentiable.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index 6428c343b3..78abf95524 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -61,31 +61,34 @@ def test_interp(): a = Function(name="a", grid=grid, staggered=NODE) sa = Function(name="as", grid=grid, staggered=x) - sp_diff = lambda a, b: sympy.simplify(a - b) == 0 + def sp_diff(a, b): + a = getattr(a, 'evaluate', a) + b = getattr(b, 'evaluate', b) + return sympy.simplify(a - b) == 0 # Base case, no interp - assert interp_for_fd(a, {}, expand=True) == a - assert interp_for_fd(a, {x: x}, expand=True) == a - assert interp_for_fd(sa, {}, expand=True) == sa - assert interp_for_fd(sa, {x: x + x.spacing/2}, expand=True) == sa + assert interp_for_fd(a, {}) == a + assert interp_for_fd(a, {x: x}) == a + assert interp_for_fd(sa, {}) == sa + assert interp_for_fd(sa, {x: x + x.spacing/2}) == sa # Base case, interp - assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}, expand=True), + assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}), .5*a + .5*a.shift(x, x.spacing)) - assert sp_diff(interp_for_fd(sa, {x: x}, expand=True), + assert sp_diff(interp_for_fd(sa, {x: x}), .5*sa + .5*sa.shift(x, -x.spacing)) # Mul case, split interp - assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}, expand=True), - sa * interp_for_fd(a, {x: x + x.spacing/2}, expand=True)) - assert sp_diff(interp_for_fd(a*sa, {x: x}, expand=True), - a * interp_for_fd(sa, {x: x}, expand=True)) + assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}), + sa * interp_for_fd(a, {x: x + x.spacing/2})) + assert sp_diff(interp_for_fd(a*sa, {x: x}), + a * interp_for_fd(sa, {x: x})) # Add case, split interp - assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}, expand=True), - sa + interp_for_fd(a, {x: x + x.spacing/2}, expand=True)) - assert sp_diff(interp_for_fd(a + sa, {x: x}, expand=True), - a + interp_for_fd(sa, {x: x}, expand=True)) + assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}), + sa + interp_for_fd(a, {x: x + x.spacing/2})) + assert sp_diff(interp_for_fd(a + sa, {x: x}), + a + interp_for_fd(sa, {x: x})) @pytest.mark.parametrize('ndim', [1, 2, 3])