Skip to content

Commit

Permalink
Validate input types in CorrectionWithGradient.eval
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Oct 30, 2023
1 parent 0b0a339 commit 33d6534
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
31 changes: 24 additions & 7 deletions src/correctionlib_gradients/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,34 @@ def _get_result_size(self, inputs: dict[str, jax.Array]) -> int:
class CorrectionWithGradient:
def __init__(self, c: schema.Correction):
self._dag = CorrectionDAG(c)
self._input_names = [v.name for v in c.inputs]
self._input_vars = c.inputs
self._name = c.name

def evaluate(self, *inputs: Value) -> Value:
if (n_in := len(inputs)) != (n_expected := len(self._input_names)):
def evaluate(self, *inputs: Value) -> jax.Array:
self._check_num_inputs(inputs)
inputs_as_jax = tuple(jax.numpy.array(i) for i in inputs)
self._check_input_types(inputs_as_jax)
input_names = (v.name for v in self._input_vars)

input_dict = dict(zip(input_names, inputs_as_jax))
return self._dag.evaluate(input_dict)

def _check_num_inputs(self, inputs: tuple[Value, ...]) -> None:
if (n_in := len(inputs)) != (n_expected := len(self._input_vars)):
msg = (
f"This correction requires {n_expected} input(s), {n_in} provided."
f" Required inputs are {self._input_names}"
f" Required inputs are {[v.name for v in self._input_vars]}"
)
raise ValueError(msg)

inputs_as_jax = (jax.numpy.array(i) for i in inputs)
input_dict = dict(zip(self._input_names, inputs_as_jax))
return self._dag.evaluate(input_dict)
def _check_input_types(self, inputs: tuple[jax.Array, ...]) -> None:
for i, v in enumerate(inputs):
in_type = v.dtype
expected_type_str = self._input_vars[i].type
expected_type = {"real": float, "string": str, "int": int}[expected_type_str]
if in_type != expected_type:
msg = (
f"Variable '{self._input_vars[i].name}' has type {in_type}"
f" instead of the expected {expected_type.__name__}"
)
raise ValueError(msg)
7 changes: 7 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ def test_wrong_input_length():
cg.evaluate(0.0, 1.0)


def test_wrong_input_type():
cg = CorrectionWithGradient(schemas["scale"])

with pytest.raises(ValueError, match="Variable 'x' has type int64 instead of the expected float"):
cg.evaluate(0)


def test_missing_input():
cg = CorrectionWithGradient(schemas["scale"])

Expand Down

0 comments on commit 33d6534

Please sign in to comment.