Skip to content

Commit

Permalink
Add tests for CorrectionWithGrad + jax.vmap
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Nov 4, 2023
1 parent ceb50d8 commit 9e454c4
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def test_vectorized_evaluate_scale(jit):
cg = CorrectionWithGradient(schemas["scale"])
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
x = np.array([0.0, 1.0])
values, grads = np.vectorize(jax.value_and_grad(evaluate))(x)
values, grads = jax.vmap(jax.value_and_grad(evaluate))(x)
assert len(values) == len(x)
assert np.allclose(values, [1.234, 1.234])
assert len(grads) == len(x)
Expand All @@ -256,19 +256,23 @@ def test_vectorized_evaluate_scale(jit):
def test_mixed_scalar_array_inputs(jit):
cg = CorrectionWithGradient(schemas["scale-two-inputs"])
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
# using np.vectorize because jax.vmap does not do broadcasting
# and only accepts jax.Array inputs
values, grads = np.vectorize(jax.value_and_grad(evaluate))(42.0, [1.234, 8.0])
assert len(values) == 2
assert np.allclose(values, [1.234, 1.234])
assert len(grads) == 2
assert np.allclose(grads, [0.0, 0.0])

# using np.vectorize because jax.vmap does not do broadcasting
# and only accepts jax.Array inputs
values, grads = np.vectorize(jax.value_and_grad(evaluate))(jnp.array(42.0), [1.234, 8.0])
assert len(values) == 2
assert np.allclose(values, [1.234, 1.234])
assert len(grads) == 2
assert np.allclose(grads, [0.0, 0.0])

values, grads = np.vectorize(jax.value_and_grad(evaluate))(jnp.array(42.0), jnp.array([1.234, 8.0]))
values, grads = jax.vmap(jax.value_and_grad(evaluate))(jnp.array([42.0, 42.0]), jnp.array([1.234, 8.0]))
assert len(values) == 2
assert np.allclose(values, [1.234, 1.234])
assert len(grads) == 2
Expand Down Expand Up @@ -301,6 +305,7 @@ def test_vectorized_evaluate_simple_uniform_binning():
expected_values = [3.47303922, 5.15686275, 1.0]
assert np.allclose(values, expected_values)

# using np.vectorize because of https://github.com/eguiraud/correctionlib-gradients/issues/42
values, grads = np.vectorize(jax.value_and_grad(cg.evaluate))(x)
assert len(values) == 3
assert np.allclose(values, expected_values)
Expand All @@ -311,13 +316,14 @@ def test_vectorized_evaluate_simple_uniform_binning():

def test_vectorized_evaluate_simple_nonuniform_binning():
cg = CorrectionWithGradient(schemas["simple-nonuniform-binning"])
x = [3.0, 5.0, 11.0] # 11. overflows: it tests clamping
x = jnp.array([3.0, 5.0, 11.0]) # 11. overflows: it tests clamping

values = cg.evaluate(x)
assert len(values) == 3
expected_values = [2.0, 3.08611111, 1]
assert np.allclose(values, expected_values)

# using np.vectorize because of https://github.com/eguiraud/correctionlib-gradients/issues/42
values, grads = np.vectorize(jax.value_and_grad(cg.evaluate))(x)
assert len(values) == 3
assert np.allclose(values, expected_values)
Expand Down Expand Up @@ -367,8 +373,8 @@ def test_simple_formula_vectorized(jit):
cg = CorrectionWithGradient(schemas["simple-formula"])
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
# pass in different kinds of arrays/collections
for x in [1.0, 2.0, 3.0], np.arange(1, 4, dtype=np.float32), jnp.arange(1, 4, dtype=np.float32):
values, grads = np.vectorize(jax.value_and_grad(evaluate))(x)
for x in np.arange(1, 4, dtype=np.float32), jnp.arange(1, 4, dtype=np.float32):
values, grads = jax.vmap(jax.value_and_grad(evaluate))(x)
assert len(values) == 3
assert np.allclose(values, [1.0, 4.0, 9.0])
assert len(grads) == 3
Expand Down Expand Up @@ -397,28 +403,28 @@ def test_complex_formula_nojax():
assert np.allclose(values, [26.047519582032493, 43.77948741392216])


@pytest.mark.parametrize("jit", [False, True])
def test_complex_formula_vectorized(jit):
cg = CorrectionWithGradient(schemas["complex-formula"])
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
# pass in different kinds of arrays/collections
y = jnp.array([2.0, 3.0])
for x in np.array([1.0, 2.0]), jnp.array([1.0, 2.0]):
values, grads = jax.vmap(jax.value_and_grad(evaluate, argnums=[0, 1]))(x, y)
assert len(values) == 2
assert np.allclose(values, [26.047519582032493, 43.77948741392216])
assert len(grads) == 2
assert np.allclose(grads[0], [19.25876411, 30.04082763])
assert np.allclose(grads[1], [2.29401694, 3.00643777])


@pytest.mark.parametrize("jit", [False, True])
def test_formula_with_parameters(jit):
cg = CorrectionWithGradient(schemas["formula-with-parameters"])
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
values, grads = np.vectorize(jax.value_and_grad(evaluate))([1.0, 2.0])
x = jnp.array([1.0, 2.0])
values, grads = jax.vmap(jax.value_and_grad(evaluate))(x)
assert len(values) == 2
assert np.allclose(values, [5.0, 7.0])
assert len(grads) == 2
assert np.allclose(grads, [2.0, 2.0])


# TODO this does not work, seemingly because of np.vectorize
# choking on the gradients being a tuple.
# @pytest.mark.parametrize("jit", [False, True])
# def test_complex_formula_vectorized(jit):
# cg = CorrectionWithGradient(schemas["complex-formula"])
# evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
# # pass in different kinds of arrays/collections
# y = jnp.array(2.)
# for x in [1.0, 2.0], np.array([1., 2.]), jnp.array([1., 2.]):
# values, grads = np.vectorize(jax.value_and_grad(evaluate, argnums=[0,1]))(x, y)
# assert len(values) == 2
# assert np.allclose(values, [9.963647609000805, 14.7853985])
# assert len(grads) == 88
# assert np.allclose(grads, [2.0, 4.0, 6.0])

0 comments on commit 9e454c4

Please sign in to comment.