Skip to content

Commit

Permalink
tests: Update after dropping redundant evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Sep 24, 2024
1 parent 1563b80 commit 602c448
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,34 @@ def test_interp():
a = Function(name="a", grid=grid, staggered=NODE)
sa = Function(name="as", grid=grid, staggered=x)

sp_diff = lambda a, b: sympy.simplify(a - b) == 0
def sp_diff(a, b):
a = getattr(a, 'evaluate', a)
b = getattr(b, 'evaluate', b)
return sympy.simplify(a - b) == 0

# Base case, no interp
assert interp_for_fd(a, {}, expand=True) == a
assert interp_for_fd(a, {x: x}, expand=True) == a
assert interp_for_fd(sa, {}, expand=True) == sa
assert interp_for_fd(sa, {x: x + x.spacing/2}, expand=True) == sa
assert interp_for_fd(a, {}) == a
assert interp_for_fd(a, {x: x}) == a
assert interp_for_fd(sa, {}) == sa
assert interp_for_fd(sa, {x: x + x.spacing/2}) == sa

# Base case, interp
assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}, expand=True),
assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}),
.5*a + .5*a.shift(x, x.spacing))
assert sp_diff(interp_for_fd(sa, {x: x}, expand=True),
assert sp_diff(interp_for_fd(sa, {x: x}),
.5*sa + .5*sa.shift(x, -x.spacing))

# Mul case, split interp
assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}, expand=True),
sa * interp_for_fd(a, {x: x + x.spacing/2}, expand=True))
assert sp_diff(interp_for_fd(a*sa, {x: x}, expand=True),
a * interp_for_fd(sa, {x: x}, expand=True))
assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}),
sa * interp_for_fd(a, {x: x + x.spacing/2}))
assert sp_diff(interp_for_fd(a*sa, {x: x}),
a * interp_for_fd(sa, {x: x}))

# Add case, split interp
assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}, expand=True),
sa + interp_for_fd(a, {x: x + x.spacing/2}, expand=True))
assert sp_diff(interp_for_fd(a + sa, {x: x}, expand=True),
a + interp_for_fd(sa, {x: x}, expand=True))
assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}),
sa + interp_for_fd(a, {x: x + x.spacing/2}))
assert sp_diff(interp_for_fd(a + sa, {x: x}),
a + interp_for_fd(sa, {x: x}))


@pytest.mark.parametrize('ndim', [1, 2, 3])
Expand Down

0 comments on commit 602c448

Please sign in to comment.