Skip to content

Commit

Permalink
Move cubature rules around (inside the module) to fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Jun 13, 2024
1 parent 147cd38 commit 9ac1d81
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 104 deletions.
200 changes: 101 additions & 99 deletions probdiffeq/solvers/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,101 @@
from probdiffeq.solvers import markov


class PositiveCubatureRule(containers.NamedTuple):
"""Cubature rule with positive weights."""

points: Array
weights_sqrtm: Array


def cubature_third_order_spherical(input_shape) -> PositiveCubatureRule:
"""Third-order spherical cubature integration."""
assert len(input_shape) <= 1
if len(input_shape) == 1:
(d,) = input_shape
points_mat, weights_sqrtm = _third_order_spherical_params(d=d)
return PositiveCubatureRule(points=points_mat, weights_sqrtm=weights_sqrtm)

# If input_shape == (), compute weights via input_shape=(1,)
# and 'squeeze' the points.
points_mat, weights_sqrtm = _third_order_spherical_params(d=1)
(S, _) = points_mat.shape
points = np.reshape(points_mat, (S,))
return PositiveCubatureRule(points=points, weights_sqrtm=weights_sqrtm)


def _third_order_spherical_params(*, d):
eye_d = np.eye(d) * np.sqrt(d)
pts = np.concatenate((eye_d, -1 * eye_d))
weights_sqrtm = np.ones((2 * d,)) / np.sqrt(2.0 * d)
return pts, weights_sqrtm


def cubature_unscented_transform(input_shape, r=1.0) -> PositiveCubatureRule:
"""Unscented transform."""
assert len(input_shape) <= 1
if len(input_shape) == 1:
(d,) = input_shape
points_mat, weights_sqrtm = _unscented_transform_params(d=d, r=r)
return PositiveCubatureRule(points=points_mat, weights_sqrtm=weights_sqrtm)

# If input_shape == (), compute weights via input_shape=(1,)
# and 'squeeze' the points.
points_mat, weights_sqrtm = _unscented_transform_params(d=1, r=r)
(S, _) = points_mat.shape
points = np.reshape(points_mat, (S,))
return PositiveCubatureRule(points=points, weights_sqrtm=weights_sqrtm)


def _unscented_transform_params(d, *, r):
eye_d = np.eye(d) * np.sqrt(d + r)
zeros = np.zeros((1, d))
pts = np.concatenate((eye_d, zeros, -1 * eye_d))
_scale = d + r
weights_sqrtm1 = np.ones((d,)) / np.sqrt(2.0 * _scale)
weights_sqrtm2 = np.sqrt(r / _scale)
weights_sqrtm = np.hstack((weights_sqrtm1, weights_sqrtm2, weights_sqrtm1))
return pts, weights_sqrtm


def cubature_gauss_hermite(input_shape, degree=5) -> PositiveCubatureRule:
"""(Statistician's) Gauss-Hermite cubature.
The number of cubature points is `prod(input_shape)**degree`.
"""
assert len(input_shape) == 1
(dim,) = input_shape

# Roots of the probabilist/statistician's Hermite polynomials (in Numpy...)
_roots = special.roots_hermitenorm(n=degree, mu=True)
pts, weights, sum_of_weights = _roots
weights = weights / sum_of_weights

# Transform into jax arrays and take square root of weights
pts = np.asarray(pts)
weights_sqrtm = np.sqrt(np.asarray(weights))

# Build a tensor grid and return class
tensor_pts = _tensor_points(pts, d=dim)
tensor_weights_sqrtm = _tensor_weights(weights_sqrtm, d=dim)
return PositiveCubatureRule(points=tensor_pts, weights_sqrtm=tensor_weights_sqrtm)


# how does this generalise to an input_shape instead of an input_dimension?
# via tree_map(lambda s: _tensor_points(x, s), input_shape)?


def _tensor_weights(*args, **kwargs):
mesh = _tensor_points(*args, **kwargs)
return np.prod_along_axis(mesh, axis=1)


def _tensor_points(x, /, *, d):
x_mesh = np.meshgrid(*([x] * d))
y_mesh = tree_util.tree_map(lambda s: np.reshape(s, (-1,)), x_mesh)
return np.stack(y_mesh).T


def prior_ibm(num_derivatives, output_scale=None):
"""Construct an adaptive(/continuous-time), multiply-integrated Wiener process."""
output_scale = output_scale or np.ones_like(impl.prototypes.output_scale())
Expand Down Expand Up @@ -150,9 +245,10 @@ def correction_ts1(*, ode_order=1) -> _ODEConstraintTaylor:
)


def correction_slr0(cubature_fun=None) -> _ODEConstraintStatistical:
def correction_slr0(
cubature_fun=cubature_third_order_spherical,
) -> _ODEConstraintStatistical:
"""Zeroth-order statistical linear regression."""
cubature_fun = cubature_fun or third_order_spherical
linearise_fun = impl.linearise.ode_statistical_1st(cubature_fun)
return _ODEConstraintStatistical(
ode_order=1,
Expand All @@ -161,107 +257,13 @@ def correction_slr0(cubature_fun=None) -> _ODEConstraintStatistical:
)


def correction_slr1(cubature_fun=None) -> _ODEConstraintStatistical:
def correction_slr1(
cubature_fun=cubature_third_order_spherical,
) -> _ODEConstraintStatistical:
"""First-order statistical linear regression."""
cubature_fun = cubature_fun or third_order_spherical
linearise_fun = impl.linearise.ode_statistical_0th(cubature_fun)
return _ODEConstraintStatistical(
ode_order=1,
linearise_fun=linearise_fun,
string_repr=f"<SLR0 with ode_order={1}>",
)


class PositiveCubatureRule(containers.NamedTuple):
"""Cubature rule with positive weights."""

points: Array
weights_sqrtm: Array


def third_order_spherical(input_shape) -> PositiveCubatureRule:
"""Third-order spherical cubature integration."""
assert len(input_shape) <= 1
if len(input_shape) == 1:
(d,) = input_shape
points_mat, weights_sqrtm = _third_order_spherical_params(d=d)
return PositiveCubatureRule(points=points_mat, weights_sqrtm=weights_sqrtm)

# If input_shape == (), compute weights via input_shape=(1,)
# and 'squeeze' the points.
points_mat, weights_sqrtm = _third_order_spherical_params(d=1)
(S, _) = points_mat.shape
points = np.reshape(points_mat, (S,))
return PositiveCubatureRule(points=points, weights_sqrtm=weights_sqrtm)


def _third_order_spherical_params(*, d):
eye_d = np.eye(d) * np.sqrt(d)
pts = np.concatenate((eye_d, -1 * eye_d))
weights_sqrtm = np.ones((2 * d,)) / np.sqrt(2.0 * d)
return pts, weights_sqrtm


def unscented_transform(input_shape, r=1.0) -> PositiveCubatureRule:
"""Unscented transform."""
assert len(input_shape) <= 1
if len(input_shape) == 1:
(d,) = input_shape
points_mat, weights_sqrtm = _unscented_transform_params(d=d, r=r)
return PositiveCubatureRule(points=points_mat, weights_sqrtm=weights_sqrtm)

# If input_shape == (), compute weights via input_shape=(1,)
# and 'squeeze' the points.
points_mat, weights_sqrtm = _unscented_transform_params(d=1, r=r)
(S, _) = points_mat.shape
points = np.reshape(points_mat, (S,))
return PositiveCubatureRule(points=points, weights_sqrtm=weights_sqrtm)


def _unscented_transform_params(d, *, r):
eye_d = np.eye(d) * np.sqrt(d + r)
zeros = np.zeros((1, d))
pts = np.concatenate((eye_d, zeros, -1 * eye_d))
_scale = d + r
weights_sqrtm1 = np.ones((d,)) / np.sqrt(2.0 * _scale)
weights_sqrtm2 = np.sqrt(r / _scale)
weights_sqrtm = np.hstack((weights_sqrtm1, weights_sqrtm2, weights_sqrtm1))
return pts, weights_sqrtm


def gauss_hermite(input_shape, degree=5) -> PositiveCubatureRule:
"""(Statistician's) Gauss-Hermite cubature.
The number of cubature points is `prod(input_shape)**degree`.
"""
assert len(input_shape) == 1
(dim,) = input_shape

# Roots of the probabilist/statistician's Hermite polynomials (in Numpy...)
_roots = special.roots_hermitenorm(n=degree, mu=True)
pts, weights, sum_of_weights = _roots
weights = weights / sum_of_weights

# Transform into jax arrays and take square root of weights
pts = np.asarray(pts)
weights_sqrtm = np.sqrt(np.asarray(weights))

# Build a tensor grid and return class
tensor_pts = _tensor_points(pts, d=dim)
tensor_weights_sqrtm = _tensor_weights(weights_sqrtm, d=dim)
return PositiveCubatureRule(points=tensor_pts, weights_sqrtm=tensor_weights_sqrtm)


# how does this generalise to an input_shape instead of an input_dimension?
# via tree_map(lambda s: _tensor_points(x, s), input_shape)?


def _tensor_weights(*args, **kwargs):
mesh = _tensor_points(*args, **kwargs)
return np.prod_along_axis(mesh, axis=1)


def _tensor_points(x, /, *, d):
x_mesh = np.meshgrid(*([x] * d))
y_mesh = tree_util.tree_map(lambda s: np.reshape(s, (-1,)), x_mesh)
return np.stack(y_mesh).T
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def case_slr1():
@testing.case()
def case_slr1_gauss_hermite():
try:
return components.correction_slr1(cubature_fun=components.gauss_hermite)
return components.correction_slr1(
cubature_fun=components.cubature_gauss_hermite
)
except NotImplementedError:
return "not_implemented"
raise RuntimeError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

def test_third_order_spherical_vs_unscented_transform_scalar_input():
"""Assert that UT with r=0 equals the third-order spherical rule."""
tos = components.third_order_spherical(input_shape=())
ut = components.unscented_transform(input_shape=(), r=0.0)
tos = components.cubature_third_order_spherical(input_shape=())
ut = components.cubature_unscented_transform(input_shape=(), r=0.0)
tos_points, tos_weights = tos.points, tos.weights_sqrtm
ut_points, ut_weights = ut.points, ut.weights_sqrtm
for x, y in [(ut_weights, tos_weights), (ut_points, tos_points)]:
Expand All @@ -18,8 +18,8 @@ def test_third_order_spherical_vs_unscented_transform_scalar_input():

def test_third_order_spherical_vs_unscented_transform(n=4):
"""Assert that UT with r=0 equals the third-order spherical rule."""
tos = components.third_order_spherical(input_shape=(n,))
ut = components.unscented_transform(input_shape=(n,), r=0.0)
tos = components.cubature_third_order_spherical(input_shape=(n,))
ut = components.cubature_unscented_transform(input_shape=(n,), r=0.0)
tos_points, tos_weights = tos.points, tos.weights_sqrtm
ut_points, ut_weights = ut.points, ut.weights_sqrtm
for x, y in [(ut_weights, tos_weights), (ut_points, tos_points)]:
Expand Down

0 comments on commit 9ac1d81

Please sign in to comment.