Skip to content

Commit

Permalink
Solver-test-case refactor (#469)
Browse files Browse the repository at this point in the history
* Todo in taylor.py

* Formatting in test case file

* repr in implementations

* Test utilities

* cubature is cubature_rule now

* Cleaned up conftest

* fixed grid and while loop tests updated

* Solve and save at updated

* simulate terminal values updated

* Tests for dense output updated

* test_edges -> test_misc

* Fixed grid differentiability tests

* JVP tests for fixed grid solvers

* Improved test readability

* Update and rerun internal benchmark

* Removed debug_nan flag

* Cubature rule function in DenseSLR1

* SLR0 takes cubature factory

* Updated benchmark

* Fixed a doctest
  • Loading branch information
pnkraemer authored Mar 16, 2023
1 parent 2516e8c commit afb3783
Show file tree
Hide file tree
Showing 19 changed files with 2,961 additions and 3,231 deletions.
5,153 changes: 2,396 additions & 2,757 deletions docs/benchmarks/lotka_volterra/internal.ipynb

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions docs/benchmarks/lotka_volterra/internal.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ jupyter:
Let's find the fastest solver of the Lotka-Volterra problem, a standard benchmark problem. It is low-dimensional, not stiff, and generally poses no major problems for any numerical solver.

```python
import functools

import jax
import jax.experimental.ode
import jax.numpy as jnp
Expand Down Expand Up @@ -126,22 +128,22 @@ def solver_to_method(solver):
Should we linearize with a Taylor-approximation or by moment matching?

```python
def cubature_to_slr1(cubature, *, ode_shape):
def cubature_to_slr1(cubature_rule_fn, *, ode_shape):
return recipes.DenseSLR1.from_params(
ode_shape=ode_shape,
cubature=cubature,
cubature_rule_fn=cubature_rule_fn,
)


# Different linearisation styles
ode_shape = u0.shape
ts1 = recipes.DenseTS1.from_params(ode_shape=ode_shape)
sci = cubature.ThirdOrderSpherical.from_params(input_shape=ode_shape)
ut = cubature.UnscentedTransform.from_params(input_shape=ode_shape, r=1.0)
gh = cubature.GaussHermite.from_params(input_shape=ode_shape, degree=3)
slr1_sci = cubature_to_slr1(sci, ode_shape=ode_shape)
slr1_ut = cubature_to_slr1(ut, ode_shape=ode_shape)
slr1_gh = cubature_to_slr1(gh, ode_shape=ode_shape)
sci_fn = cubature.ThirdOrderSpherical.from_params
ut_fn = functools.partial(cubature.UnscentedTransform.from_params, r=1.0)
gh_fn = functools.partial(cubature.GaussHermite.from_params, degree=3)
slr1_sci = cubature_to_slr1(sci_fn, ode_shape=ode_shape)
slr1_ut = cubature_to_slr1(ut_fn, ode_shape=ode_shape)
slr1_gh = cubature_to_slr1(gh_fn, ode_shape=ode_shape)


# Methods
Expand Down
24 changes: 12 additions & 12 deletions probdiffeq/implementations/_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,30 +184,30 @@ def complete_correction(self, extrapolated, cache):

@jax.tree_util.register_pytree_node_class
class StatisticalFirstOrder(_collections.AbstractCorrection):
def __init__(self, ode_order, cubature):
def __init__(self, ode_order, cubature_rule):
if ode_order > 1:
raise ValueError

super().__init__(ode_order=ode_order)
self.cubature = cubature
self.cubature_rule = cubature_rule

@classmethod
def from_params(cls, ode_order):
sci_fn = cubature_module.ThirdOrderSpherical.from_params
cubature = sci_fn(input_shape=())
return cls(ode_order=ode_order, cubature=cubature)
cubature_rule = sci_fn(input_shape=())
return cls(ode_order=ode_order, cubature_rule=cubature_rule)

def tree_flatten(self):
# todo: should this call super().tree_flatten()?
children = (self.cubature,)
children = (self.cubature_rule,)
aux = (self.ode_order,)
return children, aux

@classmethod
def tree_unflatten(cls, aux, children):
(cubature,) = children
(cubature_rule,) = children
(ode_order,) = aux
return cls(ode_order=ode_order, cubature=cubature)
return cls(ode_order=ode_order, cubature_rule=cubature_rule)

def begin_correction(self, x: Normal, /, vector_field, t, p):
raise NotImplementedError
Expand Down Expand Up @@ -261,18 +261,18 @@ def transform_sigma_points(self, rv: Normal):

# Multiply and shift the unit-points
m_marg1_x = rv.mean[0]
sigma_points_centered = self.cubature.points * r_marg1_x[None]
sigma_points_centered = self.cubature_rule.points * r_marg1_x[None]
sigma_points = m_marg1_x[None] + sigma_points_centered

# Scale the shifted points with square-root weights
_w = self.cubature.weights_sqrtm
_w = self.cubature_rule.weights_sqrtm
sigma_points_centered_normed = sigma_points_centered * _w
return sigma_points, sigma_points_centered, sigma_points_centered_normed

def center(self, fx):
fx_mean = self.cubature.weights_sqrtm**2 @ fx
fx_mean = self.cubature_rule.weights_sqrtm**2 @ fx
fx_centered = fx - fx_mean[None]
fx_centered_normed = fx_centered * self.cubature.weights_sqrtm
fx_centered_normed = fx_centered * self.cubature_rule.weights_sqrtm
return fx_mean, fx_centered, fx_centered_normed

def linearization_matrices(
Expand All @@ -284,7 +284,7 @@ def linearization_matrices(
# It seems to be different to Section VI.B in
# https://arxiv.org/pdf/2207.00426.pdf,
# because the implementation below avoids sqrt-down-dates
# pts_centered_normed = pts_centered * self.cubature.weights_sqrtm[:, None]
# pts_centered_normed = pts_centered * self.cubature_rule.weights_sqrtm[:, None]
_, (std_noi_mat, linop_mat) = _sqrtm.revert_conditional_noisefree(
R_X_F=pts_centered_normed[:, None], R_X=fx_centered_normed[:, None]
)
Expand Down
24 changes: 15 additions & 9 deletions probdiffeq/implementations/blockdiag/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,42 @@ class BlockDiagStatisticalFirstOrder(_collections.AbstractCorrection):
"""

def __init__(self, ode_shape, ode_order, cubature):
def __init__(self, ode_shape, ode_order, cubature_rule):
if ode_order > 1:
raise ValueError

super().__init__(ode_order=ode_order)
self.ode_shape = ode_shape

self._mm = _scalar.StatisticalFirstOrder(ode_order=ode_order, cubature=cubature)
self._mm = _scalar.StatisticalFirstOrder(
ode_order=ode_order, cubature_rule=cubature_rule
)

@property
def cubature(self):
return self._mm.cubature
def cubature_rule(self):
return self._mm.cubature_rule

def tree_flatten(self):
# todo: should this call super().tree_flatten()?
children = (self.cubature,)
children = (self.cubature_rule,)
aux = self.ode_order, self.ode_shape
return children, aux

@classmethod
def tree_unflatten(cls, aux, children):
(cubature,) = children
(cubature_rule,) = children
ode_order, ode_shape = aux
return cls(ode_order=ode_order, ode_shape=ode_shape, cubature=cubature)
return cls(
ode_order=ode_order, ode_shape=ode_shape, cubature_rule=cubature_rule
)

@classmethod
def from_params(cls, ode_shape, ode_order):
cubature_fn = cubature_module.ThirdOrderSpherical.from_params_blockdiag
cubature = cubature_fn(input_shape=ode_shape)
return cls(ode_shape=ode_shape, ode_order=ode_order, cubature=cubature)
cubature_rule = cubature_fn(input_shape=ode_shape)
return cls(
ode_shape=ode_shape, ode_order=ode_order, cubature_rule=cubature_rule
)

def begin_correction(self, extrapolated, /, vector_field, t, p):
# Vmap relevant functions
Expand Down
28 changes: 16 additions & 12 deletions probdiffeq/implementations/dense/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,14 @@ def __init__(self, ode_shape, ode_order, linearise_fn):
self.e1_vect = functools.partial(select_vect, i=self.ode_order)

@classmethod
def from_params(cls, ode_shape, ode_order, cubature=None):
if cubature is None:
make_rule_fn = cubature_module.ThirdOrderSpherical.from_params
cubature = make_rule_fn(input_shape=ode_shape)

linearise_fn = functools.partial(linearise_slr0, cubature_rule=cubature)
def from_params(
cls,
ode_shape,
ode_order,
cubature_rule_fn=cubature_module.ThirdOrderSpherical.from_params,
):
cubature_rule = cubature_rule_fn(input_shape=ode_shape)
linearise_fn = functools.partial(linearise_slr0, cubature_rule=cubature_rule)
return cls(ode_shape=ode_shape, ode_order=ode_order, linearise_fn=linearise_fn)

def tree_flatten(self):
Expand Down Expand Up @@ -352,12 +354,14 @@ def __init__(self, ode_shape, ode_order, linearise_fn):
self.e1_vect = functools.partial(select_vect, i=self.ode_order)

@classmethod
def from_params(cls, ode_shape, ode_order, cubature=None):
if cubature is None:
make_rule_fn = cubature_module.ThirdOrderSpherical.from_params
cubature = make_rule_fn(input_shape=ode_shape)

linearise_fn = functools.partial(linearise_slr1, cubature_rule=cubature)
def from_params(
cls,
ode_shape,
ode_order,
cubature_rule_fn=cubature_module.ThirdOrderSpherical.from_params,
):
cubature_rule = cubature_rule_fn(input_shape=ode_shape)
linearise_fn = functools.partial(linearise_slr1, cubature_rule=cubature_rule)
return cls(ode_shape=ode_shape, ode_order=ode_order, linearise_fn=linearise_fn)

def tree_flatten(self):
Expand Down
36 changes: 29 additions & 7 deletions probdiffeq/implementations/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax

from probdiffeq import cubature
from probdiffeq.implementations import _collections
from probdiffeq.implementations.blockdiag import corr as blockdiag_corr
from probdiffeq.implementations.blockdiag import extra as blockdiag_extra
Expand Down Expand Up @@ -40,6 +41,11 @@ def tree_unflatten(cls, _aux, children):
correction, extrapolation = children
return cls(correction=correction, extrapolation=extrapolation)

def __repr__(self):
name = self.__class__.__name__
n = self.extrapolation.num_derivatives
return f"<{name} with num_derivatives={n}>"


@jax.tree_util.register_pytree_node_class
class IsoTS0(AbstractImplementation[iso_corr.IsoTaylorZerothOrder, iso_extra.IsoIBM]):
Expand Down Expand Up @@ -68,14 +74,16 @@ class BlockDiagSLR1(
"""

@classmethod
def from_params(cls, *, ode_shape, cubature=None, ode_order=1, num_derivatives=4):
if cubature is None:
def from_params(
cls, *, ode_shape, cubature_rule=None, ode_order=1, num_derivatives=4
):
if cubature_rule is None:
correction = blockdiag_corr.BlockDiagStatisticalFirstOrder.from_params(
ode_shape=ode_shape, ode_order=ode_order
)
else:
correction = blockdiag_corr.BlockDiagStatisticalFirstOrder(
ode_shape=ode_shape, ode_order=ode_order, cubature=cubature
ode_shape=ode_shape, ode_order=ode_order, cubature_rule=cubature_rule
)
extrapolation = blockdiag_extra.BlockDiagIBM.from_params(
ode_shape=ode_shape, num_derivatives=num_derivatives
Expand Down Expand Up @@ -133,9 +141,16 @@ class DenseSLR1(
AbstractImplementation[dense_corr.DenseStatisticalFirstOrder, dense_extra.DenseIBM]
):
@classmethod
def from_params(cls, *, ode_shape, cubature=None, ode_order=1, num_derivatives=4):
def from_params(
cls,
*,
ode_shape,
cubature_rule_fn=cubature.ThirdOrderSpherical.from_params,
ode_order=1,
num_derivatives=4,
):
correction = dense_corr.DenseStatisticalFirstOrder.from_params(
ode_shape=ode_shape, ode_order=ode_order, cubature=cubature
ode_shape=ode_shape, ode_order=ode_order, cubature_rule_fn=cubature_rule_fn
)
extrapolation = dense_extra.DenseIBM.from_params(
ode_shape=ode_shape, num_derivatives=num_derivatives
Expand All @@ -159,9 +174,16 @@ class DenseSLR0(
"""

@classmethod
def from_params(cls, *, ode_shape, cubature=None, ode_order=1, num_derivatives=4):
def from_params(
cls,
*,
ode_shape,
cubature_rule_fn=cubature.ThirdOrderSpherical.from_params,
ode_order=1,
num_derivatives=4,
):
correction = dense_corr.DenseStatisticalZerothOrder.from_params(
ode_shape=ode_shape, ode_order=ode_order, cubature=cubature
ode_shape=ode_shape, ode_order=ode_order, cubature_rule_fn=cubature_rule_fn
)
extrapolation = dense_extra.DenseIBM.from_params(
ode_shape=ode_shape, num_derivatives=num_derivatives
Expand Down
7 changes: 7 additions & 0 deletions probdiffeq/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ def __init__(self, strategy, *, output_scale_sqrtm):
# todo: overwrite init_fn()?
self._output_scale_sqrtm = output_scale_sqrtm

def __repr__(self):
name = self.__class__.__name__
args = (
f"strategy={self.strategy}, output_scale_sqrtm={self._output_scale_sqrtm}"
)
return f"{name}({args})"

def step_fn(self, *, state, vector_field, dt, parameters):
# Pre-error-estimate steps
linearisation_pt, cache_ext = self.strategy.begin_extrapolation(
Expand Down
2 changes: 2 additions & 0 deletions probdiffeq/taylor.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ def jet_embedded(*c, degree):
fx, jvp_fn = jax.linearize(jet_embedded_deg, *taylor_coefficients)

# Compute the next set of coefficients.
# todo: can we jax.fori_loop() this loop?
# the running variable (cs_padded) should have constant size
cs = [(fx[deg - 1] / deg)]
for k in range(deg, min(2 * deg, num)):
# The Jacobian of the embedded jet is block-banded,
Expand Down
49 changes: 49 additions & 0 deletions probdiffeq/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Test utilities."""

from probdiffeq import solvers
from probdiffeq.implementations import recipes
from probdiffeq.strategies import filters


def generate_solver(
*,
solver_factory=solvers.MLESolver,
strategy_factory=filters.Filter,
impl_factory=recipes.IsoTS0.from_params,
**impl_factory_kwargs,
):
"""Generate a solver.
Examples
--------
>>> from jax.config import config
>>> config.update("jax_platform_name", "cpu")
>>> from probdiffeq import solvers
>>> from probdiffeq.implementations import recipes
>>> from probdiffeq.strategies import smoothers
>>> print(generate_solver())
MLESolver(strategy=Filter(implementation=<IsoTS0 with num_derivatives=4>))
>>> print(generate_solver(num_derivatives=1))
MLESolver(strategy=Filter(implementation=<IsoTS0 with num_derivatives=1>))
>>> print(generate_solver(solver_factory=solvers.DynamicSolver))
DynamicSolver(strategy=Filter(implementation=<IsoTS0 with num_derivatives=4>))
>>> impl_fcty = recipes.DenseTS1.from_params
>>> strat_fcty = smoothers.Smoother
>>> print(generate_solver(strategy_factory=strat_fcty, impl_factory=impl_fcty, ode_shape=(1,))) # noqa: E501
MLESolver(strategy=Smoother(implementation=<DenseTS1 with num_derivatives=4>))
"""
impl = impl_factory(**impl_factory_kwargs)
strat = strategy_factory(impl)

# I am not too happy with the need for this distinction below...

if solver_factory in [solvers.MLESolver, solvers.DynamicSolver]:
return solver_factory(strat)

scale_sqrtm = impl.extrapolation.init_output_scale_sqrtm()
return solver_factory(strat, output_scale_sqrtm=scale_sqrtm)
Loading

0 comments on commit afb3783

Please sign in to comment.