Pre-compute spline once when bin contents are constant #41

Merged
1 commit merged on Nov 4, 2023
58 changes: 7 additions & 51 deletions src/correctionlib_gradients/_base.py
@@ -1,50 +1,19 @@
# SPDX-FileCopyrightText: 2023-present Enrico Guiraud <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
from typing import Callable, Iterable, TypeAlias, cast
from typing import TypeAlias, cast

import correctionlib.schemav2 as schema
import jax
import jax.numpy as jnp
import numpy as np
from scipy.interpolate import CubicSpline # type: ignore[import-untyped]

import correctionlib_gradients._utils as utils
from correctionlib_gradients._differentiable_spline import SplineWithGrad, make_differentiable_spline
from correctionlib_gradients._formuladag import FormulaDAG
from correctionlib_gradients._typedefs import Value


def midpoints(x: jax.Array) -> jax.Array:
return 0.5 * (x[1:] + x[:-1])


# TODO: the callable returned is not traceable by JAX, so it does not support jax.jit, jax.vmap etc.
# TODO: would it make more sense if the value returned was the exact value given by the binning,
# while the derivative is calculated by spline.derivative?
def make_differentiable_spline(x: jax.Array, y: jax.Array) -> Callable[[Value], Value]:
spline = CubicSpline(midpoints(x), y, bc_type="clamped")
dspline = spline.derivative(1)

def clip(x: Value) -> Value:
# so that extrapolation works
return np.clip(x, spline.x[0], spline.x[-1])

@jax.custom_vjp
def eval_spline(x): # type: ignore[no-untyped-def]
return spline(clip(x))

def eval_spline_fwd(x): # type: ignore[no-untyped-def]
return eval_spline(x), dspline(clip(x))

def eval_spline_bwd(res, g): # type: ignore[no-untyped-def]
return ((res * g),)

eval_spline.defvjp(eval_spline_fwd, eval_spline_bwd)

return cast(Callable[[Value], Value], eval_spline)


DAGNode: TypeAlias = float | schema.Binning | FormulaDAG
DAGNode: TypeAlias = float | SplineWithGrad | FormulaDAG


class CorrectionDAG:
@@ -60,7 +29,7 @@ def __init__(self, c: schema.Correction):

Transformations applied:
- correctionlib.schema.Formula -> FormulaAST, a JAX-friendly formula evaluator object.
- [TODO] Binning nodes with constant bin contents -> differentiable relaxation.
- Binning nodes with constant bin contents -> differentiable relaxation.
"""
self.input_names = [v.name for v in c.inputs]
match c.data:
@@ -73,7 +42,7 @@ def __init__(self, c: schema.Correction):
" (one or more of the bin contents are not simple scalars). This is not supported."
)
raise ValueError(msg)
self.node = c.data
self.node = make_differentiable_spline(c.data)
case schema.Binning(flow=flow):
flow = cast(str, flow) # type: ignore[has-type]
msg = f"Correction '{c.name}' contains a Binning correction with `{flow=}`. Only 'clamp' is supported."
@@ -93,23 +62,10 @@ def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
return jnp.array(x)
else:
return jnp.repeat(x, result_size)
case schema.Binning(edges=_edges, content=[*_values], input=_var, flow="clamp"):
# to make mypy happy
var: str = _var # type: ignore[has-type]
edges: Iterable[float] | schema.UniformBinning = _edges # type: ignore[has-type]
values: jax.Array = _values # type: ignore[has-type]

if isinstance(edges, schema.UniformBinning):
xs = np.linspace(edges.low, edges.high, edges.n + 1)
else:
xs = np.array(edges)
s = make_differentiable_spline(xs, values)
return s(inputs[var])
case SplineWithGrad() as s:
return s(inputs[s.var])
case FormulaDAG() as f:
return f.evaluate(inputs)
case _: # pragma: no cover
msg = "Unsupported type of node in the computation graph. This should never happen."
raise RuntimeError(msg)


class CorrectionWithGradient:
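Taken together, the changes to _base.py move spline construction out of the evaluation path: for a Binning node with constant bin contents, make_differentiable_spline(c.data) now runs once in CorrectionDAG.__init__ and the resulting SplineWithGrad is stored as the DAG node, so evaluate only looks up the pre-fitted spline. A minimal sketch of exercising that path (the correction below and its values are illustrative, not part of this PR; CorrectionDAG and its evaluate signature are taken from the diff above):

```python
import correctionlib.schemav2 as schema
import jax.numpy as jnp

from correctionlib_gradients._base import CorrectionDAG

# Hypothetical correction: a clamped Binning whose bin contents are plain scalars.
c = schema.Correction(
    name="sf",
    version=2,
    inputs=[schema.Variable(name="pt", type="real")],
    output=schema.Variable(name="weight", type="real"),
    data=schema.Binning(
        nodetype="binning",
        input="pt",
        edges=[0.0, 10.0, 20.0, 40.0],
        content=[1.1, 1.08, 1.05],
        flow="clamp",
    ),
)

dag = CorrectionDAG(c)  # the cubic spline is fitted here, once
for pt in (5.0, 15.0, 30.0):
    # subsequent evaluations reuse the cached SplineWithGrad node
    print(dag.evaluate({"pt": jnp.array(pt)}))
```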
61 changes: 61 additions & 0 deletions src/correctionlib_gradients/_differentiable_spline.py
@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: 2023-present Enrico Guiraud <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import dataclass
from typing import Callable

import jax
import numpy as np
from correctionlib.schemav2 import Binning, UniformBinning
from scipy.interpolate import CubicSpline # type: ignore[import-untyped]

from correctionlib_gradients._typedefs import Value


def _midpoints(x: jax.Array) -> jax.Array:
return 0.5 * (x[1:] + x[:-1])


@dataclass
class SplineWithGrad:
spline: Callable[[Value], Value]
var: str

def __call__(self, v: Value) -> Value:
return self.spline(v)


# TODO: the callable returned is not traceable by JAX, so it does not support jax.jit, jax.vmap etc.
# TODO: would it make more sense if the value returned was the exact value given by the binning,
# while the derivative is calculated by spline.derivative?
def make_differentiable_spline(b: Binning) -> SplineWithGrad:
var: str = b.input

if isinstance(b.edges, UniformBinning):
xs = np.linspace(b.edges.low, b.edges.high, b.edges.n + 1)
else:
xs = np.array(b.edges)

ys = b.content
assert all(isinstance(y, float) for y in ys) # noqa: S101

spline = CubicSpline(_midpoints(xs), ys, bc_type="clamped")
dspline = spline.derivative(1)

def clip(x: Value) -> Value:
# so that extrapolation works
return np.clip(x, spline.x[0], spline.x[-1])

@jax.custom_vjp
def eval_spline(x): # type: ignore[no-untyped-def]
return spline(clip(x))

def eval_spline_fwd(x): # type: ignore[no-untyped-def]
return eval_spline(x), dspline(clip(x))

def eval_spline_bwd(res, g): # type: ignore[no-untyped-def]
return ((res * g),)

eval_spline.defvjp(eval_spline_fwd, eval_spline_bwd)

return SplineWithGrad(spline=eval_spline, var=var)
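For completeness, a minimal sketch of calling the new helper directly (the Binning values are illustrative; as the TODO above notes, jax.jit and jax.vmap are not supported, but eager evaluation and jax.grad go through the custom VJP):

```python
import correctionlib.schemav2 as schema
import jax

from correctionlib_gradients._differentiable_spline import make_differentiable_spline

# Hypothetical Binning with scalar bin contents; the spline interpolates the
# contents at the bin midpoints and is clamped outside the binned range.
b = schema.Binning(
    nodetype="binning",
    input="pt",
    edges=[0.0, 10.0, 20.0, 40.0],
    content=[1.1, 1.08, 1.05],
    flow="clamp",
)

s = make_differentiable_spline(b)
print(s.var)              # "pt": the input variable the spline is bound to
print(s(15.0))            # smooth value near the content of the 10-20 bin
print(jax.grad(s)(15.0))  # gradient supplied by the spline's analytic derivative
```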