diff --git a/tests/test_base.py b/tests/test_base.py index 080bf4c..aeab5d7 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -148,9 +148,10 @@ def test_vectorized_evaluate_scale(jit): cg = CorrectionWithGradient(schemas["scale"]) evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate x = np.array([0.0, 1.0]) - values, grads = jax.value_and_grad(evaluate)(x) + values, grads = np.vectorize(jax.value_and_grad(evaluate))(x) + assert len(values) == len(x) assert np.allclose(values, [1.234, 1.234]) - assert len(grads) == 2 + assert len(grads) == len(x) assert grads[0] == 0.0 assert grads[1] == 0.0 @@ -162,10 +163,12 @@ def test_vectorized_evaluate_simple_uniform_binning(): values = cg.evaluate(x) # here and below, the magic numbers have been checked by plotting # the bins and their contents, the corresponding spline, and its derivative. + assert len(values) == 3 assert np.allclose(values, [3.47303922, 5.15686275, 1.0]) grads = np.vectorize(jax.grad(cg.evaluate))(x) expected_grad = [0.995098039, 0.0, 0.0] + assert len(grads) == len(expected_grad) assert np.allclose(grads, expected_grad) @@ -174,8 +177,10 @@ def test_vectorized_evaluate_simple_nonuniform_binning(): x = [3.0, 5.0, 11.0] # 11. overflows: it tests clamping values = cg.evaluate(x) + assert len(values) == 3 assert np.allclose(values, [2.0, 3.08611111, 1]) grads = np.vectorize(jax.grad(cg.evaluate))(x) expected_grad = [0.794444444, 0.0, 0.0] + assert len(grads) == len(expected_grad) assert np.allclose(grads, expected_grad)