Skip to content

Commit

Permalink
Add test for Formula with parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Nov 3, 2023
1 parent ca24868 commit 9172cb3
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@
variables=["x", "y"],
),
),
"formula-with-parameters": schemav2.Correction(
name="formula with parameters",
version=2,
inputs=[schemav2.Variable(name="x", type="real")],
output=schemav2.Variable(name="a scale", type="real"),
data=schemav2.Formula(
nodetype="formula", expression="[0]*x + [1]", parser="TFormula", variables=["x"], parameters=[2.0, 3.0]
),
),
# this type of correction is unsupported
"categorical": schemav2.Correction(
name="categorical",
Expand Down Expand Up @@ -382,6 +391,17 @@ def test_complex_formula_nojax():
assert np.allclose(values, [26.047519582032493, 43.77948741392216])


@pytest.mark.parametrize("jit", [False, True])
def test_formula_with_parameters(jit):
cg = CorrectionWithGradient(schemas["formula-with-parameters"])
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
values, grads = np.vectorize(jax.value_and_grad(evaluate))([1.0, 2.0])
assert len(values) == 2
assert np.allclose(values, [5.0, 7.0])
assert len(grads) == 2
assert np.allclose(grads, [2.0, 2.0])


# TODO this does not work, seemingly because of np.vectorize
# choking on the gradients being a tuple.
# @pytest.mark.parametrize("jit", [False, True])
Expand Down

0 comments on commit 9172cb3

Please sign in to comment.