Skip to content

Commit

Permalink
Add support for FormulaRefs inside Binning corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Dec 1, 2023
1 parent c1d689a commit 89d84ce
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
6 changes: 3 additions & 3 deletions src/correctionlib_gradients/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def __init__(self, c: schema.Correction):
if all(isinstance(v, float) for v in values): # type: ignore[has-type]
# simple binning
self.node = SplineWithGrad.from_binning(c.data)
elif all(isinstance(v, schema.Formula) for v in values): # type: ignore[has-type]
self.node = CompoundBinning(binning)
elif all(isinstance(v, (schema.Formula, schema.FormulaRef)) for v in values): # type: ignore[has-type]
self.node = CompoundBinning(binning, c.generic_formulas)
else:
msg = (
f"Correction '{c.name}' contains a Binning correction but the bin contents"
" are neither all scalars nor all Formulas. This is not supported."
" are neither all scalars nor all Formulas/FormulaRefs. This is not supported."
)
raise ValueError(msg)
case schema.Binning(flow=flow):
Expand Down
14 changes: 9 additions & 5 deletions src/correctionlib_gradients/_compound_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class CompoundBinning:
edges: jax.Array
values: list[FormulaDAG]

def __init__(self, b: schema.Binning):
def __init__(self, b: schema.Binning, generic_formulas: list[schema.Formula]):
# nothing else is supported
assert b.flow == "clamp" # noqa: S101

Expand All @@ -24,12 +24,16 @@ def __init__(self, b: schema.Binning):
else:
self.edges = jnp.array(b.edges)

variable = schema.Variable(name=self.var, type="real")
self.values = []
for value in b.content:
assert isinstance(value, schema.Formula) # noqa: S101
variable = schema.Variable(name=self.var, type="real")
formula = FormulaDAG(value, inputs=[variable])
self.values.append(formula)
if isinstance(value, schema.FormulaRef):
formula = generic_formulas[value.index].copy()
formula.parameters = value.parameters
self.values.append(FormulaDAG(formula, inputs=[variable]))
else:
assert isinstance(value, schema.Formula) # noqa: S101
self.values.append(FormulaDAG(value, inputs=[variable]))

def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
x = inputs[self.var]
Expand Down
16 changes: 8 additions & 8 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,11 @@ def test_compound_nonuniform_binning():


def test_compound_binning_with_formularef():
with pytest.raises(
ValueError,
match=(
"Correction 'reftest' contains a Binning correction but "
"the bin contents are neither all scalars nor all Formulas. This is not supported."
),
):
CorrectionWithGradient(schemas["compound-binning-with-formularef"])
cg = CorrectionWithGradient(schemas["compound-binning-with-formularef"])

value = cg.evaluate(0.5)
assert math.isclose(value, 0.2)

value, grad = jax.value_and_grad(cg.evaluate)(0.5)
assert math.isclose(value, 0.2)
assert math.isclose(grad, 0.2)

0 comments on commit 89d84ce

Please sign in to comment.