Skip to content

Commit

Permalink
Explicitly test that the number of returned values matches expectations
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Oct 29, 2023
1 parent 0a83967 commit f739358
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ def test_vectorized_evaluate_scale(jit):
cg = CorrectionWithGradient(schemas["scale"])
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
x = np.array([0.0, 1.0])
values, grads = jax.value_and_grad(evaluate)(x)
values, grads = np.vectorize(jax.value_and_grad(evaluate))(x)
assert len(values) == len(x)
assert np.allclose(values, [1.234, 1.234])
assert len(grads) == 2
assert len(grads) == len(x)
assert grads[0] == 0.0
assert grads[1] == 0.0

Expand All @@ -162,10 +163,12 @@ def test_vectorized_evaluate_simple_uniform_binning():
values = cg.evaluate(x)
# here and below, the magic numbers have been checked by plotting
# the bins and their contents, the corresponding spline, and its derivative.
assert len(values) == 3
assert np.allclose(values, [3.47303922, 5.15686275, 1.0])

grads = np.vectorize(jax.grad(cg.evaluate))(x)
expected_grad = [0.995098039, 0.0, 0.0]
assert len(grads) == len(expected_grad)
assert np.allclose(grads, expected_grad)


Expand All @@ -174,8 +177,10 @@ def test_vectorized_evaluate_simple_nonuniform_binning():
x = [3.0, 5.0, 11.0] # 11. overflows: it tests clamping

values = cg.evaluate(x)
assert len(values) == 3
assert np.allclose(values, [2.0, 3.08611111, 1])

grads = np.vectorize(jax.grad(cg.evaluate))(x)
expected_grad = [0.794444444, 0.0, 0.0]
assert len(grads) == len(expected_grad)
assert np.allclose(grads, expected_grad)

0 comments on commit f739358

Please sign in to comment.