Skip to content

Commit

Permalink
Fix: return an array of the correct size from scale corrections, add …
Browse files Browse the repository at this point in the history
…test
  • Loading branch information
eguiraud committed Oct 29, 2023
1 parent 0aac516 commit fa225b2
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
35 changes: 30 additions & 5 deletions src/correctionlib_gradients/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def eval_spline_bwd(res, g): # type: ignore[no-untyped-def]
class CorrectionDAG:
"""A JAX-friendly representation of a correctionlib.schemav2.Correction's DAG."""

input_vars: list[schema.Variable]
input_names: list[str]
node: DAGNode

def __init__(self, c: schema.Correction):
Expand All @@ -59,7 +59,7 @@ def __init__(self, c: schema.Correction):
- correctionlib.schema.Formula -> FormulaAST, a JAX-friendly formula evaluator object.
- [TODO] Binning nodes with constant bin contents -> differentiable relaxation.
"""
self.input_vars = c.inputs
self.input_names = [v.name for v in c.inputs]
match c.data:
case float(x):
self.node = x
Expand All @@ -79,10 +79,15 @@ def __init__(self, c: schema.Correction):
msg = f"Correction '{c.name}' contains the unsupported operation type '{type(c.data).__name__}'"
raise ValueError(msg)

def evaluate(self, inputs: dict[str, Value]) -> Value:
def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
result_size = self._get_result_size(inputs)

match self.node:
case float(x):
return x
if result_size == 0:
return jax.numpy.array(x)
else:
return jax.numpy.array([x] * result_size)
case schema.Binning(edges=_edges, content=[*_values], input=_var, flow="clamp"):
# to make mypy happy
var: str = _var # type: ignore[has-type]
Expand All @@ -99,6 +104,25 @@ def evaluate(self, inputs: dict[str, Value]) -> Value:
msg = "Unsupported type of node in the computation graph. This should never happen."
raise RuntimeError(msg)

def _get_result_size(self, inputs: dict[str, jax.Array]) -> int:
"""Calculate what size the result of a DAG evaluation should have.
The size is equal to the one, common size (shape[0], or number or rows) of all
the non-scalar inputs we require, or 0 if all inputs are scalar.
An error is thrown in case the shapes of two non-scalar inputs differ.
"""
result_shape: tuple[int, ...] = ()
for value in inputs.values():
if result_shape == ():
result_shape = value.shape
elif value.shape != result_shape:
msg = "The shapes of all non-scalar inputs should match."
raise ValueError(msg)
if result_shape != ():
return result_shape[0]
else:
return 0


class CorrectionWithGradient:
def __init__(self, c: schema.Correction):
Expand All @@ -114,5 +138,6 @@ def evaluate(self, *inputs: Value) -> Value:
)
raise ValueError(msg)

input_dict = dict(zip(self._input_names, inputs))
inputs_as_jax = (jax.numpy.array(i) for i in inputs)
input_dict = dict(zip(self._input_names, inputs_as_jax))
return self._dag.evaluate(input_dict)
10 changes: 10 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ def test_unsupported_flow_type():
CorrectionWithGradient(schemas["simple-nonuniform-binning-flow-default"])


def test_evaluate_scale_nojax():
cg = CorrectionWithGradient(schemas["scale"])
value = cg.evaluate(4.2)
assert math.isclose(value, 1.234)

values = cg.evaluate([4.2, 4.2])
assert len(values) == 2
assert np.allclose(values, [1.234, 1.234])


@pytest.mark.parametrize("jit", [False, True])
def test_evaluate_scale(jit):
cg = CorrectionWithGradient(schemas["scale"])
Expand Down

0 comments on commit fa225b2

Please sign in to comment.