From 9172cb3955a23b7208531df3a5d5fcf4519941db Mon Sep 17 00:00:00 2001 From: Enrico Guiraud Date: Fri, 3 Nov 2023 16:37:00 -0600 Subject: [PATCH] Add test for Formula with parameters --- tests/test_base.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_base.py b/tests/test_base.py index 4853d50..14ff657 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -131,6 +131,15 @@ variables=["x", "y"], ), ), + "formula-with-parameters": schemav2.Correction( + name="formula with parameters", + version=2, + inputs=[schemav2.Variable(name="x", type="real")], + output=schemav2.Variable(name="a scale", type="real"), + data=schemav2.Formula( + nodetype="formula", expression="[0]*x + [1]", parser="TFormula", variables=["x"], parameters=[2.0, 3.0] + ), + ), # this type of correction is unsupported "categorical": schemav2.Correction( name="categorical", @@ -382,6 +391,17 @@ def test_complex_formula_nojax(): assert np.allclose(values, [26.047519582032493, 43.77948741392216]) +@pytest.mark.parametrize("jit", [False, True]) +def test_formula_with_parameters(jit): + cg = CorrectionWithGradient(schemas["formula-with-parameters"]) + evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate + values, grads = np.vectorize(jax.value_and_grad(evaluate))([1.0, 2.0]) + assert len(values) == 2 + assert np.allclose(values, [5.0, 7.0]) + assert len(grads) == 2 + assert np.allclose(grads, [2.0, 2.0]) + + # TODO this does not work, seemingly because of np.vectorize # choking on the gradients being a tuple. # @pytest.mark.parametrize("jit", [False, True])