diff --git a/src/correctionlib_gradients/_base.py b/src/correctionlib_gradients/_base.py index 04da229..00b8a1b 100644 --- a/src/correctionlib_gradients/_base.py +++ b/src/correctionlib_gradients/_base.py @@ -127,17 +127,34 @@ def _get_result_size(self, inputs: dict[str, jax.Array]) -> int: class CorrectionWithGradient: def __init__(self, c: schema.Correction): self._dag = CorrectionDAG(c) - self._input_names = [v.name for v in c.inputs] + self._input_vars = c.inputs self._name = c.name - def evaluate(self, *inputs: Value) -> Value: - if (n_in := len(inputs)) != (n_expected := len(self._input_names)): + def evaluate(self, *inputs: Value) -> jax.Array: + self._check_num_inputs(inputs) + inputs_as_jax = tuple(jax.numpy.array(i) for i in inputs) + self._check_input_types(inputs_as_jax) + input_names = (v.name for v in self._input_vars) + + input_dict = dict(zip(input_names, inputs_as_jax)) + return self._dag.evaluate(input_dict) + + def _check_num_inputs(self, inputs: tuple[Value, ...]) -> None: + if (n_in := len(inputs)) != (n_expected := len(self._input_vars)): msg = ( f"This correction requires {n_expected} input(s), {n_in} provided." - f" Required inputs are {self._input_names}" + f" Required inputs are {[v.name for v in self._input_vars]}" ) raise ValueError(msg) - inputs_as_jax = (jax.numpy.array(i) for i in inputs) - input_dict = dict(zip(self._input_names, inputs_as_jax)) - return self._dag.evaluate(input_dict) + def _check_input_types(self, inputs: tuple[jax.Array, ...]) -> None: + for i, v in enumerate(inputs): + in_type = v.dtype + expected_type_str = self._input_vars[i].type + expected_type = {"real": np.floating, "int": np.integer}[expected_type_str] + if not np.issubdtype(in_type, expected_type): + msg = ( + f"Variable '{self._input_vars[i].name}' has type {in_type}" + f" instead of the expected {expected_type.__name__}" + ) + raise ValueError(msg) diff --git a/tests/test_base.py b/tests/test_base.py index fc4f56b..2519027 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -110,6 +110,13 @@ def test_wrong_input_length(): cg.evaluate(0.0, 1.0) +def test_wrong_input_type(): + cg = CorrectionWithGradient(schemas["scale"]) + + with pytest.raises(ValueError, match="Variable 'x' has type int64 instead of the expected float"): + cg.evaluate(0) + + def test_missing_input(): cg = CorrectionWithGradient(schemas["scale"])