From 542314a6ccbfe826b10af8f899a664f45c162c1a Mon Sep 17 00:00:00 2001 From: Enrico Guiraud Date: Sun, 29 Oct 2023 16:14:56 -0600 Subject: [PATCH] Remove CorrectionWithGradient.eval_dict It does not work well with vectorized calls: we would like to vectorize over the values of the dictionary, but e.g. np.vectorized does not do that out of the box. We'll have to revise later in what form we can reintroduce (something like) eval_dict, if any, e.g. using kwargs instead of an input dict. --- src/correctionlib_gradients/_base.py | 13 ++++--------- tests/test_base.py | 27 ++------------------------- 2 files changed, 6 insertions(+), 34 deletions(-) diff --git a/src/correctionlib_gradients/_base.py b/src/correctionlib_gradients/_base.py index aad048f..991f218 100644 --- a/src/correctionlib_gradients/_base.py +++ b/src/correctionlib_gradients/_base.py @@ -108,16 +108,11 @@ def __init__(self, c: schema.Correction): def evaluate(self, *inputs: Value) -> Value: if (n_in := len(inputs)) != (n_expected := len(self._input_names)): - msg = f"This correction requires {n_expected} input(s), {n_in} provided" + msg = ( + f"This correction requires {n_expected} input(s), {n_in} provided." + f" Required inputs are {self._input_names}" + ) raise ValueError(msg) input_dict = dict(zip(self._input_names, inputs)) return self._dag.evaluate(input_dict) - - def eval_dict(self, inputs: dict[str, Value]) -> Value: - for n in self._input_names: - if n not in inputs: - msg = f"Variable '{n}' is required by correction '{self._name}' but is not present in input" - raise ValueError(msg) - - return self._dag.evaluate(inputs) diff --git a/tests/test_base.py b/tests/test_base.py index 2837c91..eb58a24 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -101,9 +101,9 @@ def test_missing_input(): cg = CorrectionWithGradient(schemas["scale"]) with pytest.raises( - ValueError, match="Variable 'x' is required by correction 'test scalar' but is not present in input" + ValueError, match="This correction requires 1 input\\(s\\), 0 provided. Required inputs are \\['x'\\]" ): - cg.eval_dict({}) + cg.evaluate() def test_unsupported_correction(): @@ -146,16 +146,6 @@ def test_evaluate_scale(jit): assert grad == 0.0 -@pytest.mark.parametrize("jit", [False, True]) -def test_eval_dict_scale(jit): - cg = CorrectionWithGradient(schemas["scale"]) - eval_dict = jax.jit(cg.eval_dict) if jit else cg.eval_dict - value, grad = jax.value_and_grad(eval_dict)({"x": 4.2}) - assert math.isclose(value, 1.234) - assert list(grad.keys()) == ["x"] - assert grad["x"] == 0.0 - - @pytest.mark.parametrize("jit", [False, True]) def test_vectorized_evaluate_scale(jit): cg = CorrectionWithGradient(schemas["scale"]) @@ -168,19 +158,6 @@ def test_vectorized_evaluate_scale(jit): assert grads[1] == 0.0 -@pytest.mark.parametrize("jit", [False, True]) -def test_vectorized_eval_dict_scale(jit): - cg = CorrectionWithGradient(schemas["scale"]) - eval_dict = jax.jit(cg.eval_dict) if jit else cg.eval_dict - x = np.array([0.0, 1.0]) - values, grads = jax.value_and_grad(eval_dict)({"x": x}) - assert np.allclose(values, [1.234, 1.234]) - assert list(grads.keys()) == ["x"] - assert len(grads["x"]) == 2 - assert grads["x"][0] == 0.0 - assert grads["x"][1] == 0.0 - - def test_vectorized_evaluate_simple_uniform_binning(): cg = CorrectionWithGradient(schemas["simple-uniform-binning"]) x = [3.0, 5.0, 11.0] # 11. overflows: it tests clamping