Add support for Formula corrections + tests
eguiraud committed Nov 2, 2023
1 parent 9e7c4b3 commit eb9d214
Showing 4 changed files with 392 additions and 28 deletions.
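
As a quick illustration of what this commit enables, here is a minimal sketch of evaluating and differentiating a Formula-based correction through CorrectionWithGradient. The correction name, expression and input value are made up for this example, the import path assumes the package's public API, and only the binary operations handled in _formuladag.py below are supported at this point:

import correctionlib.schemav2 as schema
import jax
import jax.numpy as jnp

from correctionlib_gradients import CorrectionWithGradient

# hypothetical correction: weight = x*x + 2*x, with a single real input "x"
c = schema.Correction(
    name="toy_formula",
    version=2,
    inputs=[schema.Variable(name="x", type="real")],
    output=schema.Variable(name="weight", type="real"),
    data=schema.Formula(
        nodetype="formula",
        expression="x*x + 2*x",
        parser="TFormula",
        variables=["x"],
    ),
)

cg = CorrectionWithGradient(c)
value = cg.evaluate(jnp.array(1.5))           # expected 5.25
grad = jax.grad(cg.evaluate)(jnp.array(1.5))  # expected 5.0, i.e. 2*x + 2 at x=1.5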
36 changes: 12 additions & 24 deletions src/correctionlib_gradients/_base.py
@@ -5,9 +5,12 @@

import correctionlib.schemav2 as schema
import jax
import jax.numpy as jnp
import numpy as np
from scipy.interpolate import CubicSpline # type: ignore[import-untyped]

import correctionlib_gradients._utils as utils
from correctionlib_gradients._formuladag import FormulaDAG
from correctionlib_gradients._typedefs import Value


@@ -41,7 +44,7 @@ def eval_spline_bwd(res, g): # type: ignore[no-untyped-def]
return cast(Callable[[Value], Value], eval_spline)


DAGNode: TypeAlias = float | schema.Binning
DAGNode: TypeAlias = float | schema.Binning | FormulaDAG


class CorrectionDAG:
@@ -75,19 +78,21 @@ def __init__(self, c: schema.Correction):
flow = cast(str, flow) # type: ignore[has-type]
msg = f"Correction '{c.name}' contains a Binning correction with `{flow=}`. Only 'clamp' is supported."
raise ValueError(msg)
case schema.Formula() as f:
self.node = FormulaDAG(f, c.inputs)
case _:
msg = f"Correction '{c.name}' contains the unsupported operation type '{type(c.data).__name__}'"
raise ValueError(msg)

def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
result_size = self._get_result_size(inputs)
result_size = utils.get_result_size(inputs)

match self.node:
case float(x):
if result_size == 0:
return jax.numpy.array(x)
return jnp.array(x)
else:
return jax.numpy.array([x] * result_size)
return jnp.repeat(x, result_size)
case schema.Binning(edges=_edges, content=[*_values], input=_var, flow="clamp"):
# to make mypy happy
var: str = _var # type: ignore[has-type]
@@ -100,29 +105,12 @@ def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
xs = np.array(edges)
s = make_differentiable_spline(xs, values)
return s(inputs[var])
case FormulaDAG() as f:
return f.evaluate(inputs)
case _: # pragma: no cover
msg = "Unsupported type of node in the computation graph. This should never happen."
raise RuntimeError(msg)

def _get_result_size(self, inputs: dict[str, jax.Array]) -> int:
"""Calculate what size the result of a DAG evaluation should have.
The size is equal to the one common size (shape[0], i.e. the number of rows) of all
the non-scalar inputs we require, or 0 if all inputs are scalar.
An error is thrown in case the shapes of two non-scalar inputs differ.
"""
result_shape: tuple[int, ...] = ()
for value in inputs.values():
if result_shape == ():
result_shape = value.shape
elif value.shape != result_shape:
msg = "The shapes of all non-scalar inputs should match."
raise ValueError(msg)
if result_shape != ():
return result_shape[0]
else:
return 0


class CorrectionWithGradient:
def __init__(self, c: schema.Correction):
@@ -132,7 +120,7 @@ def __init__(self, c: schema.Correction):

def evaluate(self, *inputs: Value) -> jax.Array:
self._check_num_inputs(inputs)
inputs_as_jax = tuple(jax.numpy.array(i) for i in inputs)
inputs_as_jax = tuple(jnp.array(i) for i in inputs)
self._check_input_types(inputs_as_jax)
input_names = (v.name for v in self._input_vars)

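Internally, CorrectionWithGradient delegates to CorrectionDAG, whose evaluate method takes a dict keyed by input name; with the change above, schema.Formula data is routed to a FormulaDAG node. A minimal sketch of that (private) layer with non-scalar inputs, using a made-up two-variable expression:

import correctionlib.schemav2 as schema
import jax.numpy as jnp

from correctionlib_gradients._base import CorrectionDAG  # private module

c = schema.Correction(
    name="toy_formula_2d",
    version=2,
    inputs=[
        schema.Variable(name="x", type="real"),
        schema.Variable(name="y", type="real"),
    ],
    output=schema.Variable(name="weight", type="real"),
    data=schema.Formula(
        nodetype="formula",
        expression="x + 2*y",
        parser="TFormula",
        variables=["x", "y"],
    ),
)

dag = CorrectionDAG(c)
# non-scalar inputs with matching shapes produce one output per row
out = dag.evaluate({"x": jnp.array([1.0, 2.0]), "y": jnp.array([10.0, 20.0])})
# expected: [21.0, 42.0]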
224 changes: 224 additions & 0 deletions src/correctionlib_gradients/_formuladag.py
@@ -0,0 +1,224 @@
# SPDX-FileCopyrightText: 2023-present Enrico Guiraud <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import dataclass
from enum import Enum, auto
from typing import TypeAlias, Union

import correctionlib.schemav2 as schema
import jax
import jax.numpy as jnp # atan2
from correctionlib._core import Formula, FormulaAst
from correctionlib._core import Variable as CPPVariable

import correctionlib_gradients._utils as utils


@dataclass
class Literal:
value: float


@dataclass
class Variable:
name: str


@dataclass
class Parameter:
idx: int


class BinaryOp(Enum):
EQUAL = auto()
NOTEQUAL = auto()
GREATER = auto()
LESS = auto()
GREATEREQ = auto()
LESSEQ = auto()
MINUS = auto()
PLUS = auto()
DIV = auto()
TIMES = auto()
POW = auto()
ATAN2 = auto()
MAX = auto()
MIN = auto()


class UnaryOp(Enum):
NEGATIVE = auto()
LOG = auto()
LOG10 = auto()
EXP = auto()
ERF = auto()
SQRT = auto()
ABS = auto()
COS = auto()
SIN = auto()
TAN = auto()
ACOS = auto()
ASIN = auto()
ATAN = auto()
COSH = auto()
SINH = auto()
TANH = auto()
ACOSH = auto()
ASINH = auto()
ATANH = auto()


FormulaNode: TypeAlias = Union[Literal, Variable, Parameter, "Op"]


@dataclass
class Op:
op: BinaryOp | UnaryOp
children: tuple[FormulaNode, ...]


class FormulaDAG:
def __init__(self, f: schema.Formula, inputs: list[schema.Variable]):
cpp_formula = Formula.from_string(f.json(), [CPPVariable.from_string(v.json()) for v in inputs])
self.input_names = [v.name for v in inputs]
self.node: FormulaNode = self._make_node(cpp_formula.ast)

def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
res = self._eval_node(self.node, inputs)
return res

def _eval_node(self, node: FormulaNode, inputs: dict[str, jax.Array]) -> jax.Array:
match node:
case Literal(value):
res_size = utils.get_result_size(inputs)
if res_size == 0:
return jnp.array(value)
else:
return jnp.repeat(value, res_size)
case Variable(name):
return inputs[name]
case Op(op=BinaryOp(), children=children):
c1, c2 = children
ev = self._eval_node
i = inputs
match node.op:
case BinaryOp.EQUAL:
return (ev(c1, i) == ev(c2, i)) + 0.0
case BinaryOp.NOTEQUAL:
return (ev(c1, i) != ev(c2, i)) + 0.0
case BinaryOp.GREATER:
return (ev(c1, i) > ev(c2, i)) + 0.0
case BinaryOp.LESS:
return (ev(c1, i) < ev(c2, i)) + 0.0
case BinaryOp.GREATEREQ:
return (ev(c1, i) >= ev(c2, i)) + 0.0
case BinaryOp.LESSEQ:
return (ev(c1, i) <= ev(c2, i)) + 0.0
case BinaryOp.MINUS:
return ev(c1, i) - ev(c2, i)
case BinaryOp.PLUS:
return ev(c1, i) + ev(c2, i)
case BinaryOp.DIV:
return ev(c1, i) / ev(c2, i)
case BinaryOp.TIMES:
return ev(c1, i) * ev(c2, i)
case BinaryOp.POW:
return ev(c1, i) ** ev(c2, i)
case BinaryOp.ATAN2:
return jnp.arctan2(ev(c1, i), ev(c2, i))
case BinaryOp.MAX:
return jnp.maximum(ev(c1, i), ev(c2, i))  # element-wise maximum
case BinaryOp.MIN:
return jnp.minimum(ev(c1, i), ev(c2, i))  # element-wise minimum
case _: # pragma: no cover
msg = f"Type of formula node not recognized ({node}). This should never happen."
raise RuntimeError(msg)

# never reached, only here to make mypy happy
return jnp.array(0.0) # pragma: no cover

def _make_node(self, ast: FormulaAst) -> FormulaNode:
match ast.nodetype:
case FormulaAst.NodeType.LITERAL:
return Literal(ast.data)
case FormulaAst.NodeType.VARIABLE:
return Variable(self.input_names[ast.data])
case FormulaAst.NodeType.BINARY:
match ast.data:
# TODO reduce code duplication (code generation?)
case FormulaAst.BinaryOp.EQUAL:
return Op(
op=BinaryOp.EQUAL,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.NOTEQUAL:
return Op(
op=BinaryOp.NOTEQUAL,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.GREATER:
return Op(
op=BinaryOp.GREATER,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.LESS:
return Op(
op=BinaryOp.LESS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.GREATEREQ:
return Op(
op=BinaryOp.GREATEREQ,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.LESSEQ:
return Op(
op=BinaryOp.LESSEQ,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MINUS:
return Op(
op=BinaryOp.MINUS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.PLUS:
return Op(
op=BinaryOp.PLUS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.DIV:
return Op(
op=BinaryOp.DIV,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.TIMES:
return Op(
op=BinaryOp.TIMES,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.POW:
return Op(
op=BinaryOp.POW,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.ATAN2:
return Op(
op=BinaryOp.ATAN2,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MAX:
return Op(
op=BinaryOp.MAX,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MIN:
return Op(
op=BinaryOp.MIN,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case _: # pragma: no cover
msg = f"Type of formula node not recognized ({ast.nodetype.name}). This should never happen."
raise ValueError(msg)

# never reached, just to make mypy happy
return Literal(0.0) # pragma: no cover
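
A short sketch of FormulaDAG used on its own (again a private API, with a made-up expression), also showing how the `+ 0.0` in _eval_node turns the boolean result of a comparison into 0.0/1.0 floats; this assumes the TFormula parser accepts the comparison syntax shown:

import correctionlib.schemav2 as schema
import jax.numpy as jnp

from correctionlib_gradients._formuladag import FormulaDAG

inputs = [schema.Variable(name="x", type="real")]
f = schema.Formula(
    nodetype="formula",
    expression="(x > 1) * x",  # 0/1 gate multiplied by x
    parser="TFormula",
    variables=["x"],
)

dag = FormulaDAG(f, inputs)
out = dag.evaluate({"x": jnp.array([0.5, 2.0])})
# the comparison yields [0.0, 1.0], so out is expected to be [0.0, 2.0]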
24 changes: 24 additions & 0 deletions src/correctionlib_gradients/_utils.py
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: 2023-present Enrico Guiraud <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
import jax


def get_result_size(inputs: dict[str, jax.Array]) -> int:
"""Calculate what size the result of a DAG evaluation should have.
The size is equal to the one common size (shape[0], i.e. the number of rows) of all
the non-scalar inputs we require, or 0 if all inputs are scalar.
An error is thrown in case the shapes of two non-scalar inputs differ.
"""
result_shape: tuple[int, ...] = ()
for value in inputs.values():
if result_shape == ():
result_shape = value.shape
elif value.shape != result_shape:
msg = "The shapes of all non-scalar inputs should match."
raise ValueError(msg)
if result_shape != ():
return result_shape[0]
else:
return 0
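
For reference, a quick sketch of how get_result_size behaves (input names and values are arbitrary):

import jax.numpy as jnp

from correctionlib_gradients._utils import get_result_size

get_result_size({"x": jnp.array(1.0), "y": jnp.array(2.0)})                 # 0: all inputs are scalar
get_result_size({"x": jnp.array([1.0, 2.0]), "y": jnp.array([3.0, 4.0])})  # 2: common shape[0]
# inputs whose shapes differ raise ValueError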
