From 5efef333383b6ebbf812495fbba276394bfd9088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Mon, 17 Jun 2024 09:16:06 +0200 Subject: [PATCH] Condense the directory structure of `probdiffeq.impl` (#764) * Restructure impl._impl.py and move some components * Move DenseConditional and DenseStats * Move DenseTransform * Move DenseVariable * Move DenseHiddenModel * Migrate the old type checks and error messages * Move IsotropicPrototype * Move IsotropicSSMUtil * Move IsotropicVariable * Move IsotropicStats * Move IsotropicLinearisation * Move IsotropicConditional * Move IsotropicTransform * Move IsotropicHiddenModel and thereby complete moving isotropic implementations * Move ScalarPrototypes * Move ScalarSSMUtil * Move ScalarVariable * Move ScalarStats * Move ScalarLinearisation * Fix an import-error in impl.variable * Move ScalarConditional * Move ScalarTransform * Delete the (outdated) impl.scalar._transform.py * Move ScalarHiddenModel and thereby complete moving Scalar implementations * Move BlockDiagPrototype * Delete unused files in impl.blockdiag * Move BlockDiagSSMUtil * Move BlockDiagVariable * Move BlockDiagStats * Move BlockDiagLinearisation * Move BlockDiagConditional * Move BlockDiagTransform * Move BlockDiagHiddenModel and thereby complete moving BlockDiag implementations * Correct typing-issues in impl.impl * Improve the documentation of 'impl' --- probdiffeq/impl/__init__.py | 1 + probdiffeq/impl/_conditional.py | 192 +++++++++++++- probdiffeq/impl/_hidden_model.py | 162 +++++++++++- probdiffeq/impl/_impl.py | 252 +++++++++++-------- probdiffeq/impl/_linearise.py | 250 +++++++++++++++++- probdiffeq/impl/{dense => }/_normal.py | 0 probdiffeq/impl/_prototypes.py | 75 ++++++ probdiffeq/impl/_ssm_util.py | 249 +++++++++++++++++- probdiffeq/impl/_stats.py | 150 ++++++++++- probdiffeq/impl/_transform.py | 130 +++++++++- probdiffeq/impl/_variable.py | 73 +++++- probdiffeq/impl/blockdiag/__init__.py | 1 - probdiffeq/impl/blockdiag/_conditional.py | 66 ----- probdiffeq/impl/blockdiag/_hidden_model.py | 42 ---- probdiffeq/impl/blockdiag/_linearise.py | 31 --- probdiffeq/impl/blockdiag/_normal.py | 24 -- probdiffeq/impl/blockdiag/_prototypes.py | 22 -- probdiffeq/impl/blockdiag/_ssm_util.py | 70 ------ probdiffeq/impl/blockdiag/_stats.py | 41 --- probdiffeq/impl/blockdiag/_transform.py | 42 ---- probdiffeq/impl/blockdiag/_variable.py | 26 -- probdiffeq/impl/blockdiag/factorised_impl.py | 45 ---- probdiffeq/impl/dense/__init__.py | 1 - probdiffeq/impl/dense/_conditional.py | 46 ---- probdiffeq/impl/dense/_hidden_model.py | 57 ----- probdiffeq/impl/dense/_linearise.py | 170 ------------- probdiffeq/impl/dense/_prototypes.py | 22 -- probdiffeq/impl/dense/_ssm_util.py | 83 ------ probdiffeq/impl/dense/_stats.py | 40 --- probdiffeq/impl/dense/_transform.py | 36 --- probdiffeq/impl/dense/_variable.py | 17 -- probdiffeq/impl/dense/factorised_impl.py | 46 ---- probdiffeq/impl/isotropic/__init__.py | 1 - probdiffeq/impl/isotropic/_conditional.py | 51 ---- probdiffeq/impl/isotropic/_hidden_model.py | 39 --- probdiffeq/impl/isotropic/_linearise.py | 27 -- probdiffeq/impl/isotropic/_normal.py | 7 - probdiffeq/impl/isotropic/_prototypes.py | 22 -- probdiffeq/impl/isotropic/_ssm_util.py | 60 ----- probdiffeq/impl/isotropic/_stats.py | 46 ---- probdiffeq/impl/isotropic/_transform.py | 33 --- probdiffeq/impl/isotropic/_variable.py | 22 -- probdiffeq/impl/isotropic/factorised_impl.py | 45 ---- probdiffeq/impl/scalar/__init__.py | 1 - probdiffeq/impl/scalar/_conditional.py | 47 ---- probdiffeq/impl/scalar/_hidden_model.py | 39 --- probdiffeq/impl/scalar/_linearise.py | 25 -- probdiffeq/impl/scalar/_normal.py | 7 - probdiffeq/impl/scalar/_prototypes.py | 19 -- probdiffeq/impl/scalar/_ssm_util.py | 56 ----- probdiffeq/impl/scalar/_stats.py | 34 --- probdiffeq/impl/scalar/_transform.py | 37 --- probdiffeq/impl/scalar/_variable.py | 19 -- probdiffeq/impl/scalar/factorised_impl.py | 41 --- 54 files changed, 1422 insertions(+), 1718 deletions(-) rename probdiffeq/impl/{dense => }/_normal.py (100%) delete mode 100644 probdiffeq/impl/blockdiag/__init__.py delete mode 100644 probdiffeq/impl/blockdiag/_conditional.py delete mode 100644 probdiffeq/impl/blockdiag/_hidden_model.py delete mode 100644 probdiffeq/impl/blockdiag/_linearise.py delete mode 100644 probdiffeq/impl/blockdiag/_normal.py delete mode 100644 probdiffeq/impl/blockdiag/_prototypes.py delete mode 100644 probdiffeq/impl/blockdiag/_ssm_util.py delete mode 100644 probdiffeq/impl/blockdiag/_stats.py delete mode 100644 probdiffeq/impl/blockdiag/_transform.py delete mode 100644 probdiffeq/impl/blockdiag/_variable.py delete mode 100644 probdiffeq/impl/blockdiag/factorised_impl.py delete mode 100644 probdiffeq/impl/dense/__init__.py delete mode 100644 probdiffeq/impl/dense/_conditional.py delete mode 100644 probdiffeq/impl/dense/_hidden_model.py delete mode 100644 probdiffeq/impl/dense/_linearise.py delete mode 100644 probdiffeq/impl/dense/_prototypes.py delete mode 100644 probdiffeq/impl/dense/_ssm_util.py delete mode 100644 probdiffeq/impl/dense/_stats.py delete mode 100644 probdiffeq/impl/dense/_transform.py delete mode 100644 probdiffeq/impl/dense/_variable.py delete mode 100644 probdiffeq/impl/dense/factorised_impl.py delete mode 100644 probdiffeq/impl/isotropic/__init__.py delete mode 100644 probdiffeq/impl/isotropic/_conditional.py delete mode 100644 probdiffeq/impl/isotropic/_hidden_model.py delete mode 100644 probdiffeq/impl/isotropic/_linearise.py delete mode 100644 probdiffeq/impl/isotropic/_normal.py delete mode 100644 probdiffeq/impl/isotropic/_prototypes.py delete mode 100644 probdiffeq/impl/isotropic/_ssm_util.py delete mode 100644 probdiffeq/impl/isotropic/_stats.py delete mode 100644 probdiffeq/impl/isotropic/_transform.py delete mode 100644 probdiffeq/impl/isotropic/_variable.py delete mode 100644 probdiffeq/impl/isotropic/factorised_impl.py delete mode 100644 probdiffeq/impl/scalar/__init__.py delete mode 100644 probdiffeq/impl/scalar/_conditional.py delete mode 100644 probdiffeq/impl/scalar/_hidden_model.py delete mode 100644 probdiffeq/impl/scalar/_linearise.py delete mode 100644 probdiffeq/impl/scalar/_normal.py delete mode 100644 probdiffeq/impl/scalar/_prototypes.py delete mode 100644 probdiffeq/impl/scalar/_ssm_util.py delete mode 100644 probdiffeq/impl/scalar/_stats.py delete mode 100644 probdiffeq/impl/scalar/_transform.py delete mode 100644 probdiffeq/impl/scalar/_variable.py delete mode 100644 probdiffeq/impl/scalar/factorised_impl.py diff --git a/probdiffeq/impl/__init__.py b/probdiffeq/impl/__init__.py index 8882ce57..e2951cac 100644 --- a/probdiffeq/impl/__init__.py +++ b/probdiffeq/impl/__init__.py @@ -6,4 +6,5 @@ """State-space model implementation. Refer to the quickstart for information. +[Here is a link](https://pnkraemer.github.io/probdiffeq/examples_quickstart/easy_example/). """ diff --git a/probdiffeq/impl/_conditional.py b/probdiffeq/impl/_conditional.py index c110c410..c00c1ab3 100644 --- a/probdiffeq/impl/_conditional.py +++ b/probdiffeq/impl/_conditional.py @@ -1,6 +1,9 @@ """Conditionals.""" -from probdiffeq.backend import abc +from probdiffeq.backend import abc, functools, linalg +from probdiffeq.backend import numpy as np +from probdiffeq.impl import _normal +from probdiffeq.util import cholesky_util, cond_util class ConditionalBackend(abc.ABC): @@ -19,3 +22,190 @@ def apply(self, x, conditional, /): @abc.abstractmethod def merge(self, cond1, cond2, /): raise NotImplementedError + + +class ScalarConditional(ConditionalBackend): + def marginalise(self, rv, conditional, /): + matrix, noise = conditional + + mean = matrix @ rv.mean + noise.mean + R_stack = ((matrix @ rv.cholesky).T, noise.cholesky.T) + cholesky_T = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack) + return _normal.Normal(mean, cholesky_T.T) + + def revert(self, rv, conditional, /): + matrix, noise = conditional + + r_ext, (r_bw_p, g_bw_p) = cholesky_util.revert_conditional( + R_X_F=(matrix @ rv.cholesky).T, R_X=rv.cholesky.T, R_YX=noise.cholesky.T + ) + m_ext = matrix @ rv.mean + noise.mean + m_cond = rv.mean - g_bw_p @ m_ext + + marginal = _normal.Normal(m_ext, r_ext.T) + noise = _normal.Normal(m_cond, r_bw_p.T) + return marginal, cond_util.Conditional(g_bw_p, noise) + + def apply(self, x, conditional, /): + matrix, noise = conditional + matrix = np.squeeze(matrix) + return _normal.Normal(linalg.vector_dot(matrix, x) + noise.mean, noise.cholesky) + + def merge(self, previous, incoming, /): + A, b = previous + C, d = incoming + + g = A @ C + xi = A @ d.mean + b.mean + R_stack = ((A @ d.cholesky).T, b.cholesky.T) + Xi = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T + + noise = _normal.Normal(xi, Xi) + return cond_util.Conditional(g, noise) + + +class DenseConditional(ConditionalBackend): + def apply(self, x, conditional, /): + matrix, noise = conditional + return _normal.Normal(matrix @ x + noise.mean, noise.cholesky) + + def marginalise(self, rv, conditional, /): + matmul, noise = conditional + R_stack = ((matmul @ rv.cholesky).T, noise.cholesky.T) + cholesky_new = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T + return _normal.Normal(matmul @ rv.mean + noise.mean, cholesky_new) + + def merge(self, cond1, cond2, /): + A, b = cond1 + C, d = cond2 + + g = A @ C + xi = A @ d.mean + b.mean + Xi = cholesky_util.sum_of_sqrtm_factors( + R_stack=((A @ d.cholesky).T, b.cholesky.T) + ) + return cond_util.Conditional(g, _normal.Normal(xi, Xi.T)) + + def revert(self, rv, conditional, /): + matrix, noise = conditional + mean, cholesky = rv.mean, rv.cholesky + + # QR-decomposition + # (todo: rename revert_conditional_noisefree to + # revert_transformation_cov_sqrt()) + r_obs, (r_cor, gain) = cholesky_util.revert_conditional( + R_X_F=(matrix @ cholesky).T, R_X=cholesky.T, R_YX=noise.cholesky.T + ) + + # Gather terms and return + mean_observed = matrix @ mean + noise.mean + m_cor = mean - gain @ mean_observed + corrected = _normal.Normal(m_cor, r_cor.T) + observed = _normal.Normal(mean_observed, r_obs.T) + return observed, cond_util.Conditional(gain, corrected) + + +class IsotropicConditional(ConditionalBackend): + def apply(self, x, conditional, /): + A, noise = conditional + # if the gain is qoi-to-hidden, the data is a (d,) array. + # this is problematic for the isotropic model unless we explicitly broadcast. + if np.ndim(x) == 1: + x = x[None, :] + return _normal.Normal(A @ x + noise.mean, noise.cholesky) + + def marginalise(self, rv, conditional, /): + matrix, noise = conditional + + mean = matrix @ rv.mean + noise.mean + + R_stack = ((matrix @ rv.cholesky).T, noise.cholesky.T) + cholesky = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T + return _normal.Normal(mean, cholesky) + + def merge(self, cond1, cond2, /): + A, b = cond1 + C, d = cond2 + + g = A @ C + xi = A @ d.mean + b.mean + R_stack = ((A @ d.cholesky).T, b.cholesky.T) + Xi = cholesky_util.sum_of_sqrtm_factors(R_stack).T + + noise = _normal.Normal(xi, Xi) + return cond_util.Conditional(g, noise) + + def revert(self, rv, conditional, /): + matrix, noise = conditional + + r_ext_p, (r_bw_p, gain) = cholesky_util.revert_conditional( + R_X_F=(matrix @ rv.cholesky).T, R_X=rv.cholesky.T, R_YX=noise.cholesky.T + ) + extrapolated_cholesky = r_ext_p.T + corrected_cholesky = r_bw_p.T + + extrapolated_mean = matrix @ rv.mean + noise.mean + corrected_mean = rv.mean - gain @ extrapolated_mean + + extrapolated = _normal.Normal(extrapolated_mean, extrapolated_cholesky) + corrected = _normal.Normal(corrected_mean, corrected_cholesky) + return extrapolated, cond_util.Conditional(gain, corrected) + + +class BlockDiagConditional(ConditionalBackend): + def apply(self, x, conditional, /): + if np.ndim(x) == 1: + x = x[..., None] + + def apply_unbatch(m, s, n): + return _normal.Normal(m @ s + n.mean, n.cholesky) + + matrix, noise = conditional + return functools.vmap(apply_unbatch)(matrix, x, noise) + + def marginalise(self, rv, conditional, /): + matrix, noise = conditional + assert matrix.ndim == 3 + + mean = np.einsum("ijk,ik->ij", matrix, rv.mean) + noise.mean + + chol1 = _transpose(matrix @ rv.cholesky) + chol2 = _transpose(noise.cholesky) + R_stack = (chol1, chol2) + cholesky = functools.vmap(cholesky_util.sum_of_sqrtm_factors)(R_stack) + return _normal.Normal(mean, _transpose(cholesky)) + + def merge(self, cond1, cond2, /): + A, b = cond1 + C, d = cond2 + + g = A @ C + xi = (A @ d.mean[..., None])[..., 0] + b.mean + R_stack = (_transpose(A @ d.cholesky), _transpose(b.cholesky)) + Xi = _transpose(functools.vmap(cholesky_util.sum_of_sqrtm_factors)(R_stack)) + + noise = _normal.Normal(xi, Xi) + return cond_util.Conditional(g, noise) + + def revert(self, rv, conditional, /): + A, noise = conditional + rv_chol_upper = np.transpose(rv.cholesky, axes=(0, 2, 1)) + noise_chol_upper = np.transpose(noise.cholesky, axes=(0, 2, 1)) + A_rv_chol_upper = np.transpose(A @ rv.cholesky, axes=(0, 2, 1)) + + revert = functools.vmap(cholesky_util.revert_conditional) + r_obs, (r_cor, gain) = revert(A_rv_chol_upper, rv_chol_upper, noise_chol_upper) + + cholesky_obs = np.transpose(r_obs, axes=(0, 2, 1)) + cholesky_cor = np.transpose(r_cor, axes=(0, 2, 1)) + + # Gather terms and return + mean_observed = (A @ rv.mean[..., None])[..., 0] + noise.mean + m_cor = rv.mean - (gain @ (mean_observed[..., None]))[..., 0] + corrected = _normal.Normal(m_cor, cholesky_cor) + observed = _normal.Normal(mean_observed, cholesky_obs) + return observed, cond_util.Conditional(gain, corrected) + + +def _transpose(matrix): + return np.transpose(matrix, axes=(0, 2, 1)) diff --git a/probdiffeq/impl/_hidden_model.py b/probdiffeq/impl/_hidden_model.py index 892dca1b..474c6572 100644 --- a/probdiffeq/impl/_hidden_model.py +++ b/probdiffeq/impl/_hidden_model.py @@ -1,4 +1,7 @@ -from probdiffeq.backend import abc +from probdiffeq.backend import abc, functools +from probdiffeq.backend import numpy as np +from probdiffeq.impl import _normal +from probdiffeq.util import cholesky_util, cond_util, linop_util class HiddenModelBackend(abc.ABC): @@ -17,3 +20,160 @@ def qoi_from_sample(self, sample, /): @abc.abstractmethod def conditional_to_derivative(self, i, standard_deviation): raise NotImplementedError + + +class ScalarHiddenModel(HiddenModelBackend): + def qoi(self, rv): + return rv.mean[..., 0] + + def marginal_nth_derivative(self, rv, i): + if rv.mean.ndim > 1: + return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( + rv, i + ) + + if i > rv.mean.shape[0]: + raise ValueError + + m = rv.mean[i] + c = rv.cholesky[[i], :] + chol = cholesky_util.triu_via_qr(c.T) + return _normal.Normal(np.reshape(m, ()), np.reshape(chol, ())) + + def qoi_from_sample(self, sample, /): + return sample[0] + + def conditional_to_derivative(self, i, standard_deviation): + def A(x): + return x[[i], ...] + + bias = np.zeros(()) + eye = np.eye(1) + noise = _normal.Normal(bias, standard_deviation * eye) + linop = linop_util.parametrised_linop(lambda s, _p: A(s)) + return cond_util.Conditional(linop, noise) + + +class DenseHiddenModel(HiddenModelBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def qoi(self, rv): + if np.ndim(rv.mean) > 1: + return functools.vmap(self.qoi)(rv) + mean_reshaped = np.reshape(rv.mean, (-1, *self.ode_shape), order="F") + return mean_reshaped[0] + + def marginal_nth_derivative(self, rv, i): + if rv.mean.ndim > 1: + return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( + rv, i + ) + + m = self._select(rv.mean, i) + c = functools.vmap(self._select, in_axes=(1, None), out_axes=1)(rv.cholesky, i) + c = cholesky_util.triu_via_qr(c.T) + return _normal.Normal(m, c.T) + + def qoi_from_sample(self, sample, /): + sample_reshaped = np.reshape(sample, (-1, *self.ode_shape), order="F") + return sample_reshaped[0] + + # TODO: move to linearise.py? + def conditional_to_derivative(self, i, standard_deviation): + a0 = functools.partial(self._select, idx_or_slice=i) + + (d,) = self.ode_shape + bias = np.zeros((d,)) + eye = np.eye(d) + noise = _normal.Normal(bias, standard_deviation * eye) + linop = linop_util.parametrised_linop( + lambda s, _p: self._autobatch_linop(a0)(s) + ) + return cond_util.Conditional(linop, noise) + + def _select(self, x, /, idx_or_slice): + x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F") + if isinstance(idx_or_slice, int) and idx_or_slice > x_reshaped.shape[0]: + raise ValueError + return x_reshaped[idx_or_slice] + + @staticmethod + def _autobatch_linop(fun): + def fun_(x): + if np.ndim(x) > 1: + return functools.vmap(fun_, in_axes=1, out_axes=1)(x) + return fun(x) + + return fun_ + + +class IsotropicHiddenModel(HiddenModelBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def qoi(self, rv): + return rv.mean[..., 0, :] + + def marginal_nth_derivative(self, rv, i): + if np.ndim(rv.mean) > 2: + return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( + rv, i + ) + + if i > np.shape(rv.mean)[0]: + raise ValueError + + mean = rv.mean[i, :] + cholesky = cholesky_util.triu_via_qr(rv.cholesky[[i], :].T).T + return _normal.Normal(mean, cholesky) + + def qoi_from_sample(self, sample, /): + return sample[0, :] + + def conditional_to_derivative(self, i, standard_deviation): + def A(x): + return x[[i], ...] + + bias = np.zeros(self.ode_shape) + eye = np.eye(1) + noise = _normal.Normal(bias, standard_deviation * eye) + linop = linop_util.parametrised_linop(lambda s, _p: A(s)) + return cond_util.Conditional(linop, noise) + + +class BlockDiagHiddenModel(HiddenModelBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def qoi(self, rv): + return rv.mean[..., 0] + + def marginal_nth_derivative(self, rv, i): + if np.ndim(rv.mean) > 2: + return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( + rv, i + ) + + if i > np.shape(rv.mean)[0]: + raise ValueError + + mean = rv.mean[:, i] + cholesky = functools.vmap(cholesky_util.triu_via_qr)( + (rv.cholesky[:, i, :])[..., None] + ) + cholesky = np.transpose(cholesky, axes=(0, 2, 1)) + return _normal.Normal(mean, cholesky) + + def qoi_from_sample(self, sample, /): + return sample[..., 0] + + def conditional_to_derivative(self, i, standard_deviation): + def A(x): + return x[:, [i], ...] + + bias = np.zeros((*self.ode_shape, 1)) + eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...] + noise = _normal.Normal(bias, standard_deviation * eye) + linop = linop_util.parametrised_linop(lambda s, _p: A(s)) + return cond_util.Conditional(linop, noise) diff --git a/probdiffeq/impl/_impl.py b/probdiffeq/impl/_impl.py index 8f9a0c93..00205e21 100644 --- a/probdiffeq/impl/_impl.py +++ b/probdiffeq/impl/_impl.py @@ -2,7 +2,7 @@ import warnings -from probdiffeq.backend import abc +from probdiffeq.backend import containers from probdiffeq.backend.typing import Optional from probdiffeq.impl import ( _conditional, @@ -16,82 +16,11 @@ ) -class FactorisedImpl(abc.ABC): - """Interface for the implementations provided by the backend.""" - - @abc.abstractmethod - def linearise(self) -> _linearise.LinearisationBackend: - raise NotImplementedError - - @abc.abstractmethod - def transform(self) -> _transform.TransformBackend: - raise NotImplementedError - - @abc.abstractmethod - def conditional(self) -> _conditional.ConditionalBackend: - raise NotImplementedError - - @abc.abstractmethod - def ssm_util(self) -> _ssm_util.SSMUtilBackend: - raise NotImplementedError - - @abc.abstractmethod - def prototypes(self) -> _prototypes.PrototypeBackend: - raise NotImplementedError - - @abc.abstractmethod - def variable(self) -> _variable.VariableBackend: - raise NotImplementedError - - @abc.abstractmethod - def hidden_model(self) -> _hidden_model.HiddenModelBackend: - raise NotImplementedError - - @abc.abstractmethod - def stats(self) -> _stats.StatsBackend: - raise NotImplementedError - - -def choose(which: str, /, *, ode_shape=None) -> FactorisedImpl: - # In this function, we import outside toplevel. - # - # Why? - # 1. To avoid cyclic imports - # 2. To avoid import errors if some backends require additional dependencies - # - if which == "scalar": - import probdiffeq.impl.scalar.factorised_impl - - return probdiffeq.impl.scalar.factorised_impl.Scalar() - - if ode_shape is None: - msg = "Please provide an ODE shape." - raise ValueError(msg) - - if which == "dense": - import probdiffeq.impl.dense.factorised_impl - - return probdiffeq.impl.dense.factorised_impl.Dense(ode_shape=ode_shape) - - if which == "isotropic": - import probdiffeq.impl.isotropic.factorised_impl - - return probdiffeq.impl.isotropic.factorised_impl.Isotropic(ode_shape=ode_shape) - - if which == "blockdiag": - import probdiffeq.impl.blockdiag.factorised_impl - - return probdiffeq.impl.blockdiag.factorised_impl.BlockDiag(ode_shape=ode_shape) - msg1 = f"Implementation '{which}' unknown. " - msg2 = "Choose an implementation out of {scalar, dense, isotropic, blockdiag}." - raise ValueError(msg1 + msg2) - - class Impl: """User-facing implementation 'package'. Wrap a factorised implementations and garnish it with error messages - and a "selection" functionality. + and a `select()` functionality. """ def __init__(self) -> None: @@ -111,56 +40,169 @@ def impl_name(self): @property def linearise(self) -> _linearise.LinearisationBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.linearise() + if self._fact is not None: + return self._fact.linearise + + raise ValueError(self.error_msg()) @property def conditional(self) -> _conditional.ConditionalBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.conditional() + if self._fact is not None: + return self._fact.conditional + raise ValueError(self.error_msg()) @property def transform(self) -> _transform.TransformBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.transform() + if self._fact is not None: + return self._fact.transform + raise ValueError(self.error_msg()) @property def ssm_util(self) -> _ssm_util.SSMUtilBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.ssm_util() + if self._fact is not None: + return self._fact.ssm_util + raise ValueError(self.error_msg()) @property def prototypes(self) -> _prototypes.PrototypeBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.prototypes() + if self._fact is not None: + return self._fact.prototypes + raise ValueError(self.error_msg()) @property def variable(self) -> _variable.VariableBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.variable() + if self._fact is not None: + return self._fact.variable + raise ValueError(self.error_msg()) @property def hidden_model(self) -> _hidden_model.HiddenModelBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.hidden_model() + if self._fact is not None: + return self._fact.hidden_model + raise ValueError(self.error_msg()) @property def stats(self) -> _stats.StatsBackend: - if self._fact is None: - msg = "Select a factorisation first." - raise ValueError(msg) - return self._fact.stats() + if self._fact is not None: + return self._fact.stats + raise ValueError(self.error_msg()) + + @staticmethod + def error_msg(): + return "Select a factorisation first." + + +@containers.dataclass +class FactorisedImpl: + prototypes: _prototypes.PrototypeBackend + ssm_util: _ssm_util.SSMUtilBackend + variable: _variable.VariableBackend + stats: _stats.StatsBackend + linearise: _linearise.LinearisationBackend + conditional: _conditional.ConditionalBackend + transform: _transform.TransformBackend + hidden_model: _hidden_model.HiddenModelBackend + + +def choose(which: str, /, *, ode_shape=None) -> FactorisedImpl: + if which == "scalar": + return _select_scalar() + + if ode_shape is None: + msg = "Please provide an ODE shape." + raise ValueError(msg) + + if which == "dense": + return _select_dense(ode_shape=ode_shape) + if which == "isotropic": + return _select_isotropic(ode_shape=ode_shape) + if which == "blockdiag": + return _select_blockdiag(ode_shape=ode_shape) + + msg1 = f"Implementation '{which}' unknown. " + msg2 = "Choose an implementation out of {scalar, dense, isotropic, blockdiag}." + raise ValueError(msg1 + msg2) + + +def _select_scalar() -> FactorisedImpl: + prototypes = _prototypes.ScalarPrototype() + ssm_util = _ssm_util.ScalarSSMUtil() + variable = _variable.ScalarVariable() + stats = _stats.ScalarStats() + linearise = _linearise.ScalarLinearisation() + conditional = _conditional.ScalarConditional() + transform = _transform.ScalarTransform() + hidden_model = _hidden_model.ScalarHiddenModel() + return FactorisedImpl( + prototypes=prototypes, + ssm_util=ssm_util, + variable=variable, + stats=stats, + linearise=linearise, + conditional=conditional, + transform=transform, + hidden_model=hidden_model, + ) + + +def _select_dense(*, ode_shape) -> FactorisedImpl: + prototypes = _prototypes.DensePrototype(ode_shape=ode_shape) + ssm_util = _ssm_util.DenseSSMUtil(ode_shape=ode_shape) + linearise = _linearise.DenseLinearisation(ode_shape=ode_shape) + stats = _stats.DenseStats(ode_shape=ode_shape) + conditional = _conditional.DenseConditional() + transform = _transform.DenseTransform() + variable = _variable.DenseVariable(ode_shape=ode_shape) + hidden_model = _hidden_model.DenseHiddenModel(ode_shape=ode_shape) + return FactorisedImpl( + linearise=linearise, + transform=transform, + conditional=conditional, + ssm_util=ssm_util, + prototypes=prototypes, + variable=variable, + hidden_model=hidden_model, + stats=stats, + ) + + +def _select_isotropic(*, ode_shape) -> FactorisedImpl: + prototypes = _prototypes.IsotropicPrototype(ode_shape=ode_shape) + ssm_util = _ssm_util.IsotropicSSMUtil(ode_shape=ode_shape) + variable = _variable.IsotropicVariable(ode_shape=ode_shape) + stats = _stats.IsotropicStats(ode_shape=ode_shape) + linearise = _linearise.IsotropicLinearisation() + conditional = _conditional.IsotropicConditional() + transform = _transform.IsotropicTransform() + hidden_model = _hidden_model.IsotropicHiddenModel(ode_shape=ode_shape) + return FactorisedImpl( + prototypes=prototypes, + ssm_util=ssm_util, + variable=variable, + stats=stats, + linearise=linearise, + conditional=conditional, + transform=transform, + hidden_model=hidden_model, + ) + + +def _select_blockdiag(*, ode_shape) -> FactorisedImpl: + prototypes = _prototypes.BlockDiagPrototype(ode_shape=ode_shape) + ssm_util = _ssm_util.BlockDiagSSMUtil(ode_shape=ode_shape) + variable = _variable.BlockDiagVariable(ode_shape=ode_shape) + stats = _stats.BlockDiagStats(ode_shape=ode_shape) + linearise = _linearise.BlockDiagLinearisation() + conditional = _conditional.BlockDiagConditional() + transform = _transform.BlockDiagTransform(ode_shape=ode_shape) + hidden_model = _hidden_model.BlockDiagHiddenModel(ode_shape=ode_shape) + return FactorisedImpl( + prototypes=prototypes, + ssm_util=ssm_util, + variable=variable, + stats=stats, + linearise=linearise, + conditional=conditional, + transform=transform, + hidden_model=hidden_model, + ) diff --git a/probdiffeq/impl/_linearise.py b/probdiffeq/impl/_linearise.py index 4d3c71ce..47631af2 100644 --- a/probdiffeq/impl/_linearise.py +++ b/probdiffeq/impl/_linearise.py @@ -1,4 +1,7 @@ -from probdiffeq.backend import abc +from probdiffeq.backend import abc, functools +from probdiffeq.backend import numpy as np +from probdiffeq.impl import _normal +from probdiffeq.util import cholesky_util, linop_util class LinearisationBackend(abc.ABC): @@ -17,3 +20,248 @@ def ode_statistical_1st(self, cubature_fun): # ode_order > 1 not supported @abc.abstractmethod def ode_statistical_0th(self, cubature_fun): # ode_order > 1 not supported raise NotImplementedError + + +class ScalarLinearisation(LinearisationBackend): + def ode_taylor_0th(self, ode_order): + def linearise_fun_wrapped(fun, mean): + fx = self.ts0(fun, mean[:ode_order]) + return lambda s: s[ode_order], -fx + + return linearise_fun_wrapped + + def ode_taylor_1st(self, ode_order): + raise NotImplementedError + + def ode_statistical_1st(self, cubature_fun): + raise NotImplementedError + + def ode_statistical_0th(self, cubature_fun): + raise NotImplementedError + + @staticmethod + def ts0(fn, m): + return fn(m) + + +class DenseLinearisation(LinearisationBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def ode_taylor_0th(self, ode_order): + def linearise_fun_wrapped(fun, mean): + a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order)) + a1 = functools.partial(self._select_dy, idx_or_slice=ode_order) + + if np.shape(a0(mean)) != (expected_shape := (ode_order, *self.ode_shape)): + msg = f"{np.shape(a0(mean))} != {expected_shape}" + raise ValueError(msg) + + fx = self.ts0(fun, a0(mean)) + linop = linop_util.parametrised_linop( + lambda v, _p: self._autobatch_linop(a1)(v) + ) + return linop, -fx + + return linearise_fun_wrapped + + def ode_taylor_1st(self, ode_order): + def new(fun, mean, /): + a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order)) + a1 = functools.partial(self._select_dy, idx_or_slice=ode_order) + + if np.shape(a0(mean)) != (expected_shape := (ode_order, *self.ode_shape)): + msg = f"{np.shape(a0(mean))} != {expected_shape}" + raise ValueError(msg) + + jvp, fx = self.ts1(fun, a0(mean)) + + @self._autobatch_linop + def A(x): + x1 = a1(x) + x0 = a0(x) + return x1 - jvp(x0) + + linop = linop_util.parametrised_linop(lambda v, _p: A(v)) + return linop, -fx + + return new + + def ode_statistical_1st(self, cubature_fun): + cubature_rule = cubature_fun(input_shape=self.ode_shape) + linearise_fun = functools.partial(self.slr1, cubature_rule=cubature_rule) + + def new(fun, rv, /): + # Projection functions + a0 = self._autobatch_linop( + functools.partial(self._select_dy, idx_or_slice=0) + ) + a1 = self._autobatch_linop( + functools.partial(self._select_dy, idx_or_slice=1) + ) + + # Extract the linearisation point + m0, r_0_nonsquare = a0(rv.mean), a0(rv.cholesky) + r_0_square = cholesky_util.triu_via_qr(r_0_nonsquare.T) + linearisation_pt = _normal.Normal(m0, r_0_square.T) + + # Gather the variables and return + J, noise = linearise_fun(fun, linearisation_pt) + + def A(x): + return a1(x) - J @ a0(x) + + linop = linop_util.parametrised_linop(lambda v, _p: A(v)) + + mean, cov_lower = noise.mean, noise.cholesky + bias = _normal.Normal(-mean, cov_lower) + return linop, bias + + return new + + def ode_statistical_0th(self, cubature_fun): + cubature_rule = cubature_fun(input_shape=self.ode_shape) + linearise_fun = functools.partial(self.slr0, cubature_rule=cubature_rule) + + def new(fun, rv, /): + # Projection functions + a0 = self._autobatch_linop( + functools.partial(self._select_dy, idx_or_slice=0) + ) + a1 = self._autobatch_linop( + functools.partial(self._select_dy, idx_or_slice=1) + ) + + # Extract the linearisation point + m0, r_0_nonsquare = a0(rv.mean), a0(rv.cholesky) + r_0_square = cholesky_util.triu_via_qr(r_0_nonsquare.T) + linearisation_pt = _normal.Normal(m0, r_0_square.T) + + # Gather the variables and return + noise = linearise_fun(fun, linearisation_pt) + mean, cov_lower = noise.mean, noise.cholesky + bias = _normal.Normal(-mean, cov_lower) + linop = linop_util.parametrised_linop(lambda v, _p: a1(v)) + return linop, bias + + return new + + def _select_dy(self, x, idx_or_slice): + (d,) = self.ode_shape + x_reshaped = np.reshape(x, (-1, d), order="F") + return x_reshaped[idx_or_slice, ...] + + @staticmethod + def _autobatch_linop(fun): + def fun_(x): + if np.ndim(x) > 1: + return functools.vmap(fun_, in_axes=1, out_axes=1)(x) + return fun(x) + + return fun_ + + @staticmethod + def ts0(fn, m): + return fn(m) + + @staticmethod + def ts1(fn, m): + b, jvp = functools.linearize(fn, m) + return jvp, b - jvp(m) + + @staticmethod + def slr1(fn, x, *, cubature_rule): + """Linearise a function with first-order statistical linear regression.""" + # Create sigma-points + pts_centered = cubature_rule.points @ x.cholesky.T + pts = x.mean[None, :] + pts_centered + pts_centered_normed = pts_centered * cubature_rule.weights_sqrtm[:, None] + + # Evaluate the nonlinear function + fx = functools.vmap(fn)(pts) + fx_mean = cubature_rule.weights_sqrtm**2 @ fx + fx_centered = fx - fx_mean[None, :] + fx_centered_normed = fx_centered * cubature_rule.weights_sqrtm[:, None] + + # Compute statistical linear regression matrices + _, (cov_sqrtm_cond, linop_cond) = cholesky_util.revert_conditional_noisefree( + R_X_F=pts_centered_normed, R_X=fx_centered_normed + ) + mean_cond = fx_mean - linop_cond @ x.mean + rv_cond = _normal.Normal(mean_cond, cov_sqrtm_cond.T) + return linop_cond, rv_cond + + @staticmethod + def slr0(fn, x, *, cubature_rule): + """Linearise a function with zeroth-order statistical linear regression. + + !!! warning "Warning: highly EXPERIMENTAL feature!" + This feature is highly experimental. + There is no guarantee that it works correctly. + It might be deleted tomorrow + and without any deprecation policy. + + """ + # Create sigma-points + pts_centered = cubature_rule.points @ x.cholesky.T + pts = x.mean[None, :] + pts_centered + + # Evaluate the nonlinear function + fx = functools.vmap(fn)(pts) + fx_mean = cubature_rule.weights_sqrtm**2 @ fx + fx_centered = fx - fx_mean[None, :] + fx_centered_normed = fx_centered * cubature_rule.weights_sqrtm[:, None] + + cov_sqrtm = cholesky_util.triu_via_qr(fx_centered_normed) + + return _normal.Normal(fx_mean, cov_sqrtm.T) + + +class IsotropicLinearisation(LinearisationBackend): + def ode_taylor_1st(self, ode_order): + raise NotImplementedError + + def ode_taylor_0th(self, ode_order): + def linearise_fun_wrapped(fun, mean): + fx = self.ts0(fun, mean[:ode_order, ...]) + linop = linop_util.parametrised_linop(lambda s, _p: s[[ode_order], ...]) + return linop, -fx + + return linearise_fun_wrapped + + def ode_statistical_0th(self, cubature_fun): + raise NotImplementedError + + def ode_statistical_1st(self, cubature_fun): + raise NotImplementedError + + @staticmethod + def ts0(fn, m): + return fn(m) + + +class BlockDiagLinearisation(LinearisationBackend): + def ode_taylor_0th(self, ode_order): + def linearise_fun_wrapped(fun, mean): + m0 = mean[:, :ode_order] + fx = self.ts0(fun, m0.T) + + def a1(s): + return s[:, [ode_order], ...] + + return linop_util.parametrised_linop(lambda v, _p: a1(v)), -fx[:, None] + + return linearise_fun_wrapped + + def ode_taylor_1st(self, ode_order): + raise NotImplementedError + + def ode_statistical_0th(self, cubature_fun): + raise NotImplementedError + + def ode_statistical_1st(self, cubature_fun): + raise NotImplementedError + + @staticmethod + def ts0(fn, m): + return fn(m) diff --git a/probdiffeq/impl/dense/_normal.py b/probdiffeq/impl/_normal.py similarity index 100% rename from probdiffeq/impl/dense/_normal.py rename to probdiffeq/impl/_normal.py diff --git a/probdiffeq/impl/_prototypes.py b/probdiffeq/impl/_prototypes.py index e607ced2..5f12643f 100644 --- a/probdiffeq/impl/_prototypes.py +++ b/probdiffeq/impl/_prototypes.py @@ -1,4 +1,6 @@ from probdiffeq.backend import abc +from probdiffeq.backend import numpy as np +from probdiffeq.impl import _normal class PrototypeBackend(abc.ABC): @@ -17,3 +19,76 @@ def error_estimate(self): @abc.abstractmethod def output_scale(self): raise NotImplementedError + + +class ScalarPrototype(PrototypeBackend): + def qoi(self): + return np.empty(()) + + def observed(self): + mean = np.empty(()) + cholesky = np.empty(()) + return _normal.Normal(mean, cholesky) + + def error_estimate(self): + return np.empty(()) + + def output_scale(self): + return np.empty(()) + + +class DensePrototype(PrototypeBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def qoi(self): + return np.empty(self.ode_shape) + + def observed(self): + mean = np.empty(self.ode_shape) + cholesky = np.empty(self.ode_shape + self.ode_shape) + return _normal.Normal(mean, cholesky) + + def error_estimate(self): + return np.empty(self.ode_shape) + + def output_scale(self): + return np.empty(()) + + +class IsotropicPrototype(PrototypeBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def qoi(self): + return np.empty(self.ode_shape) + + def observed(self): + mean = np.empty((1, *self.ode_shape)) + cholesky = np.empty(()) + return _normal.Normal(mean, cholesky) + + def error_estimate(self): + return np.empty(()) + + def output_scale(self): + return np.empty(()) + + +class BlockDiagPrototype(PrototypeBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def qoi(self): + return np.empty(self.ode_shape) + + def observed(self): + mean = np.empty((*self.ode_shape, 1)) + cholesky = np.empty((*self.ode_shape, 1, 1)) + return _normal.Normal(mean, cholesky) + + def error_estimate(self): + return np.empty(self.ode_shape) + + def output_scale(self): + return np.empty(self.ode_shape) diff --git a/probdiffeq/impl/_ssm_util.py b/probdiffeq/impl/_ssm_util.py index d0fa9027..86820c99 100644 --- a/probdiffeq/impl/_ssm_util.py +++ b/probdiffeq/impl/_ssm_util.py @@ -1,6 +1,9 @@ """SSM utilities.""" -from probdiffeq.backend import abc +from probdiffeq.backend import abc, functools +from probdiffeq.backend import numpy as np +from probdiffeq.impl import _normal +from probdiffeq.util import cholesky_util, cond_util, ibm_util class SSMUtilBackend(abc.ABC): @@ -34,3 +37,247 @@ def identity_conditional(self, num_derivatives_per_ode_dimension, /): @abc.abstractmethod def standard_normal(self, num_derivatives_per_ode_dimension, /, output_scale): raise NotImplementedError + + +class ScalarSSMUtil(SSMUtilBackend): + def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): + if len(tcoeffs) != num_derivatives + 1: + msg1 = "The number of Taylor coefficients does not match " + msg2 = "the number of derivatives in the implementation." + raise ValueError(msg1 + msg2) + m0_matrix = np.stack(tcoeffs) + m0_corrected = np.reshape(m0_matrix, (-1,), order="F") + c_sqrtm0_corrected = np.zeros((num_derivatives + 1, num_derivatives + 1)) + return _normal.Normal(m0_corrected, c_sqrtm0_corrected) + + def preconditioner_apply(self, rv, p, /): + return _normal.Normal(p * rv.mean, p[:, None] * rv.cholesky) + + def preconditioner_apply_cond(self, cond, p, p_inv, /): + A, noise = cond + A = p[:, None] * A * p_inv[None, :] + noise = _normal.Normal(p * noise.mean, p[:, None] * noise.cholesky) + return cond_util.Conditional(A, noise) + + def ibm_transitions(self, num_derivatives, output_scale=1.0): + a, q_sqrtm = ibm_util.system_matrices_1d(num_derivatives, output_scale) + q0 = np.zeros((num_derivatives + 1,)) + noise = _normal.Normal(q0, q_sqrtm) + + precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) + + def discretise(dt): + p, p_inv = precon_fun(dt) + return cond_util.Conditional(a, noise), (p, p_inv) + + return discretise + + def identity_conditional(self, ndim, /): + transition = np.eye(ndim) + mean = np.zeros((ndim,)) + cov_sqrtm = np.zeros((ndim, ndim)) + noise = _normal.Normal(mean, cov_sqrtm) + return cond_util.Conditional(transition, noise) + + def standard_normal(self, ndim, /, output_scale): + mean = np.zeros((ndim,)) + cholesky = output_scale * np.eye(ndim) + return _normal.Normal(mean, cholesky) + + def update_mean(self, mean, x, /, num): + sum_updated = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) + return sum_updated / np.sqrt(num + 1) + + +class DenseSSMUtil(SSMUtilBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def ibm_transitions(self, num_derivatives, output_scale): + a, q_sqrtm = ibm_util.system_matrices_1d(num_derivatives, output_scale) + (d,) = self.ode_shape + eye_d = np.eye(d) + A = np.kron(eye_d, a) + Q = np.kron(eye_d, q_sqrtm) + + ndim = d * (num_derivatives + 1) + q0 = np.zeros((ndim,)) + noise = _normal.Normal(q0, Q) + + precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) + + def discretise(dt): + p, p_inv = precon_fun(dt) + p = np.tile(p, d) + p_inv = np.tile(p_inv, d) + return cond_util.Conditional(A, noise), (p, p_inv) + + return discretise + + def identity_conditional(self, ndim, /): + (d,) = self.ode_shape + n = ndim * d + + A = np.eye(n) + m = np.zeros((n,)) + C = np.zeros((n, n)) + return cond_util.Conditional(A, _normal.Normal(m, C)) + + def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): + if len(tcoeffs) != num_derivatives + 1: + msg1 = f"The number of Taylor coefficients {len(tcoeffs)} does not match " + msg2 = f"the number of derivatives {num_derivatives+1} in the solver." + raise ValueError(msg1 + msg2) + + if tcoeffs[0].shape != self.ode_shape: + msg = "The solver's ODE dimension does not match the initial condition." + raise ValueError(msg) + + m0_matrix = np.stack(tcoeffs) + m0_corrected = np.reshape(m0_matrix, (-1,), order="F") + + (ode_dim,) = self.ode_shape + ndim = (num_derivatives + 1) * ode_dim + c_sqrtm0_corrected = np.zeros((ndim, ndim)) + + return _normal.Normal(m0_corrected, c_sqrtm0_corrected) + + def preconditioner_apply(self, rv, p, /): + mean = p * rv.mean + cholesky = p[:, None] * rv.cholesky + return _normal.Normal(mean, cholesky) + + def preconditioner_apply_cond(self, cond, p, p_inv, /): + A, noise = cond + noise = self.preconditioner_apply(noise, p) + A = p[:, None] * A * p_inv[None, :] + return cond_util.Conditional(A, noise) + + def standard_normal(self, ndim, /, output_scale): + eye_n = np.eye(ndim) + eye_d = output_scale * np.eye(*self.ode_shape) + cholesky = np.kron(eye_d, eye_n) + mean = np.zeros((*self.ode_shape, ndim)).reshape((-1,), order="F") + return _normal.Normal(mean, cholesky) + + def update_mean(self, mean, x, /, num): + return cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) / np.sqrt( + num + 1 + ) + + +class IsotropicSSMUtil(SSMUtilBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def ibm_transitions(self, num_derivatives, output_scale): + A, q_sqrtm = ibm_util.system_matrices_1d(num_derivatives, output_scale) + q0 = np.zeros((num_derivatives + 1, *self.ode_shape)) + noise = _normal.Normal(q0, q_sqrtm) + precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) + + def discretise(dt): + p, p_inv = precon_fun(dt) + return cond_util.Conditional(A, noise), (p, p_inv) + + return discretise + + def identity_conditional(self, num_hidden_states_per_ode_dim, /): + m0 = np.zeros((num_hidden_states_per_ode_dim, *self.ode_shape)) + c0 = np.zeros((num_hidden_states_per_ode_dim, num_hidden_states_per_ode_dim)) + noise = _normal.Normal(m0, c0) + matrix = np.eye(num_hidden_states_per_ode_dim) + return cond_util.Conditional(matrix, noise) + + def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): + if len(tcoeffs) != num_derivatives + 1: + msg1 = f"The number of Taylor coefficients {len(tcoeffs)} does not match " + msg2 = f"the number of derivatives {num_derivatives+1} in the solver." + raise ValueError(msg1 + msg2) + + c_sqrtm0_corrected = np.zeros((num_derivatives + 1, num_derivatives + 1)) + m0_corrected = np.stack(tcoeffs) + return _normal.Normal(m0_corrected, c_sqrtm0_corrected) + + def preconditioner_apply(self, rv, p, /): + return _normal.Normal(p[:, None] * rv.mean, p[:, None] * rv.cholesky) + + def preconditioner_apply_cond(self, cond, p, p_inv, /): + A, noise = cond + + A_new = p[:, None] * A * p_inv[None, :] + + noise = _normal.Normal(p[:, None] * noise.mean, p[:, None] * noise.cholesky) + return cond_util.Conditional(A_new, noise) + + def standard_normal(self, num, /, output_scale): + mean = np.zeros((num, *self.ode_shape)) + cholesky = output_scale * np.eye(num) + return _normal.Normal(mean, cholesky) + + def update_mean(self, mean, x, /, num): + sum_updated = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) + return sum_updated / np.sqrt(num + 1) + + +class BlockDiagSSMUtil(SSMUtilBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def ibm_transitions(self, num_derivatives, output_scale): + system_matrices = functools.vmap(ibm_util.system_matrices_1d, in_axes=(None, 0)) + a, q_sqrtm = system_matrices(num_derivatives, output_scale) + + q0 = np.zeros((*self.ode_shape, num_derivatives + 1)) + noise = _normal.Normal(q0, q_sqrtm) + + precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) + + def discretise(dt): + p, p_inv = precon_fun(dt) + return (a, noise), (p, p_inv) + + return discretise + + def identity_conditional(self, ndim, /): + m0 = np.zeros((*self.ode_shape, ndim)) + c0 = np.zeros((*self.ode_shape, ndim, ndim)) + noise = _normal.Normal(m0, c0) + + matrix = np.ones((*self.ode_shape, 1, 1)) * np.eye(ndim, ndim)[None, ...] + return cond_util.Conditional(matrix, noise) + + def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): + if len(tcoeffs) != num_derivatives + 1: + msg1 = "The number of Taylor coefficients does not match " + msg2 = "the number of derivatives in the implementation." + raise ValueError(msg1 + msg2) + + cholesky_shape = (*self.ode_shape, num_derivatives + 1, num_derivatives + 1) + cholesky = np.zeros(cholesky_shape) + mean = np.stack(tcoeffs).T + return _normal.Normal(mean, cholesky) + + def preconditioner_apply(self, rv, p, /): + mean = p[None, :] * rv.mean + cholesky = p[None, :, None] * rv.cholesky + return _normal.Normal(mean, cholesky) + + def preconditioner_apply_cond(self, cond, p, p_inv, /): + A, noise = cond + A_new = p[None, :, None] * A * p_inv[None, None, :] + noise = self.preconditioner_apply(noise, p) + return cond_util.Conditional(A_new, noise) + + def standard_normal(self, ndim, output_scale): + mean = np.zeros((*self.ode_shape, ndim)) + cholesky = output_scale[:, None, None] * np.eye(ndim)[None, ...] + return _normal.Normal(mean, cholesky) + + def update_mean(self, mean, x, /, num): + if np.ndim(mean) > 0: + assert np.shape(mean) == np.shape(x) + return functools.vmap(self.update_mean, in_axes=(0, 0, None))(mean, x, num) + + sum_updated = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) + return sum_updated / np.sqrt(num + 1) diff --git a/probdiffeq/impl/_stats.py b/probdiffeq/impl/_stats.py index c61f114e..54bc4619 100644 --- a/probdiffeq/impl/_stats.py +++ b/probdiffeq/impl/_stats.py @@ -1,4 +1,6 @@ -from probdiffeq.backend import abc +from probdiffeq.backend import abc, functools, linalg +from probdiffeq.backend import numpy as np +from probdiffeq.impl import _normal class StatsBackend(abc.ABC): @@ -21,3 +23,149 @@ def mean(self, rv): @abc.abstractmethod def sample_shape(self, rv): raise NotImplementedError + + +class ScalarStats(StatsBackend): + def mahalanobis_norm_relative(self, u, /, rv): + res_white = (u - rv.mean) / rv.cholesky + return np.abs(res_white) / np.sqrt(rv.mean.size) + + def logpdf(self, u, /, rv): + dx = u - rv.mean + w = linalg.solve_triangular(rv.cholesky.T, dx, trans="T") + + maha_term = linalg.vector_dot(w, w) + + diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2) + slogdet = np.sum(np.log(np.abs(diagonal))) + logdet_term = 2.0 * slogdet + return -0.5 * (logdet_term + maha_term + u.size * np.log(np.pi() * 2)) + + def standard_deviation(self, rv): + if rv.cholesky.ndim > 1: + return functools.vmap(self.standard_deviation)(rv) + + return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky)) + + def mean(self, rv): + return rv.mean + + def sample_shape(self, rv): + return rv.mean.shape + + +class DenseStats(StatsBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def mahalanobis_norm_relative(self, u, /, rv): + residual_white = linalg.solve_triangular( + rv.cholesky.T, u - rv.mean, lower=False, trans="T" + ) + mahalanobis = linalg.qr_r(residual_white[:, None]) + return np.reshape(np.abs(mahalanobis) / np.sqrt(rv.mean.size), ()) + + def logpdf(self, u, /, rv): + # The cholesky factor is triangular, so we compute a cheap slogdet. + diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2) + slogdet = np.sum(np.log(np.abs(diagonal))) + + dx = u - rv.mean + residual_white = linalg.solve_triangular(rv.cholesky.T, dx, trans="T") + x1 = linalg.vector_dot(residual_white, residual_white) + x2 = 2.0 * slogdet + x3 = u.size * np.log(np.pi() * 2) + return -0.5 * (x1 + x2 + x3) + + def mean(self, rv): + return rv.mean + + def standard_deviation(self, rv): + if rv.mean.ndim > 1: + return functools.vmap(self.standard_deviation)(rv) + + diag = np.einsum("ij,ij->i", rv.cholesky, rv.cholesky) + return np.sqrt(diag) + + def sample_shape(self, rv): + return rv.mean.shape + + +class IsotropicStats(StatsBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def mahalanobis_norm_relative(self, u, /, rv): + residual_white = (rv.mean - u) / rv.cholesky + residual_white_matrix = linalg.qr_r(residual_white.T) + return np.reshape(np.abs(residual_white_matrix) / np.sqrt(rv.mean.size), ()) + + def logpdf(self, u, /, rv): + # if the gain is qoi-to-hidden, the data is a (d,) array. + # this is problematic for the isotropic model unless we explicitly broadcast. + if np.ndim(u) == 1: + u = u[None, :] + + def logpdf_scalar(x, r): + dx = x - r.mean + w = linalg.solve_triangular(r.cholesky.T, dx, trans="T") + + maha_term = linalg.vector_dot(w, w) + + diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2) + slogdet = np.sum(np.log(np.abs(diagonal))) + logdet_term = 2.0 * slogdet + return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2)) + + # Batch in the "mean" dimension and sum the results. + rv_batch = _normal.Normal(1, None) + return np.sum(functools.vmap(logpdf_scalar, in_axes=(1, rv_batch))(u, rv)) + + def mean(self, rv): + return rv.mean + + def standard_deviation(self, rv): + if rv.cholesky.ndim > 1: + return functools.vmap(self.standard_deviation)(rv) + return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky)) + + def sample_shape(self, rv): + return rv.mean.shape + + +class BlockDiagStats(StatsBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def mahalanobis_norm_relative(self, u, /, rv): + # assumes rv.chol = (d,1,1) + # return array of norms! See calibration + mean = np.reshape(rv.mean, self.ode_shape) + cholesky = np.reshape(rv.cholesky, self.ode_shape) + return (mean - u) / cholesky / np.sqrt(mean.size) + + def logpdf(self, u, /, rv): + def logpdf_scalar(x, r): + dx = x - r.mean + w = linalg.solve_triangular(r.cholesky.T, dx, trans="T") + + maha_term = linalg.vector_dot(w, w) + + diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2) + slogdet = np.sum(np.log(np.abs(diagonal))) + logdet_term = 2.0 * slogdet + return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2)) + + return np.sum(functools.vmap(logpdf_scalar)(u, rv)) + + def mean(self, rv): + return rv.mean + + def sample_shape(self, rv): + return rv.mean.shape + + def standard_deviation(self, rv): + if rv.cholesky.ndim > 1: + return functools.vmap(self.standard_deviation)(rv) + + return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky)) diff --git a/probdiffeq/impl/_transform.py b/probdiffeq/impl/_transform.py index 2873ee1a..53a0861d 100644 --- a/probdiffeq/impl/_transform.py +++ b/probdiffeq/impl/_transform.py @@ -1,4 +1,13 @@ -from probdiffeq.backend import abc +from probdiffeq.backend import abc, containers, functools +from probdiffeq.backend import numpy as np +from probdiffeq.backend.typing import Array, Callable +from probdiffeq.impl import _normal +from probdiffeq.util import cholesky_util, cond_util + + +class Transformation(containers.NamedTuple): + matmul: Callable + bias: Array class TransformBackend(abc.ABC): @@ -9,3 +18,122 @@ def marginalise(self, rv, transformation, /): @abc.abstractmethod def revert(self, rv, transformation, /): raise NotImplementedError + + +class ScalarTransform(TransformBackend): + def marginalise(self, rv, transformation, /): + # currently, assumes that A(rv.cholesky) is a vector, not a matrix. + matmul, b = transformation + cholesky_new = cholesky_util.triu_via_qr(matmul(rv.cholesky)[:, None]) + cholesky_new_squeezed = np.reshape(cholesky_new, ()) + return _normal.Normal(matmul(rv.mean) + b, cholesky_new_squeezed) + + def revert(self, rv, transformation, /): + # Assumes that A maps a vector to a scalar... + + # Extract information + A, b = transformation + + # QR-decomposition + # (todo: rename revert_conditional_noisefree + # to transformation_revert_cov_sqrt()) + r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree( + R_X_F=A(rv.cholesky)[:, None], R_X=rv.cholesky.T + ) + cholesky_obs = np.reshape(r_obs, ()) + cholesky_cor = r_cor.T + gain = np.squeeze_along_axis(gain, axis=-1) + + # Gather terms and return + m_cor = rv.mean - gain * (A(rv.mean) + b) + corrected = _normal.Normal(m_cor, cholesky_cor) + observed = _normal.Normal(A(rv.mean) + b, cholesky_obs) + return observed, cond_util.Conditional(gain, corrected) + + +class DenseTransform(TransformBackend): + def marginalise(self, rv, transformation, /): + A, b = transformation + cholesky_new = cholesky_util.triu_via_qr((A @ rv.cholesky).T).T + return _normal.Normal(A @ rv.mean + b, cholesky_new) + + def revert(self, rv, transformation, /): + A, b = transformation + mean, cholesky = rv.mean, rv.cholesky + + # QR-decomposition + # (todo: rename revert_conditional_noisefree to + # revert_transformation_cov_sqrt()) + r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree( + R_X_F=(A @ cholesky).T, R_X=cholesky.T + ) + + # Gather terms and return + m_cor = mean - gain @ (A @ mean + b) + corrected = _normal.Normal(m_cor, r_cor.T) + observed = _normal.Normal(A @ mean + b, r_obs.T) + return observed, cond_util.Conditional(gain, corrected) + + +class IsotropicTransform(TransformBackend): + def marginalise(self, rv, transformation, /): + A, b = transformation + mean, cholesky = rv.mean, rv.cholesky + cholesky_new = cholesky_util.triu_via_qr((A @ cholesky).T) + cholesky_squeezed = np.reshape(cholesky_new, ()) + return _normal.Normal((A @ mean) + b, cholesky_squeezed) + + def revert(self, rv, transformation, /): + A, b = transformation + mean, cholesky = rv.mean, rv.cholesky + + # QR-decomposition + # (todo: rename revert_conditional_noisefree + # to revert_transformation_cov_sqrt()) + r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree( + R_X_F=(A @ cholesky).T, R_X=cholesky.T + ) + cholesky_obs = np.reshape(r_obs, ()) + cholesky_cor = r_cor.T + + # Gather terms and return + mean_observed = A @ mean + b + m_cor = mean - gain * mean_observed + corrected = _normal.Normal(m_cor, cholesky_cor) + observed = _normal.Normal(mean_observed, cholesky_obs) + return observed, cond_util.Conditional(gain, corrected) + + +class BlockDiagTransform(TransformBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def marginalise(self, rv, transformation, /): + A, b = transformation + mean, cholesky = rv.mean, rv.cholesky + + A_cholesky = A @ cholesky + cholesky = functools.vmap(cholesky_util.triu_via_qr)(_transpose(A_cholesky)) + mean = A @ mean + b + return _normal.Normal(mean, cholesky) + + def revert(self, rv, transformation, /): + A, bias = transformation + cholesky_upper = np.transpose(rv.cholesky, axes=(0, -1, -2)) + A_cholesky_upper = _transpose(A @ rv.cholesky) + + revert_fun = functools.vmap(cholesky_util.revert_conditional_noisefree) + r_obs, (r_cor, gain) = revert_fun(A_cholesky_upper, cholesky_upper) + cholesky_obs = _transpose(r_obs) + cholesky_cor = _transpose(r_cor) + + # Gather terms and return + mean_observed = (A @ rv.mean) + bias + m_cor = rv.mean - (gain * (mean_observed[..., None]))[..., 0] + corrected = _normal.Normal(m_cor, cholesky_cor) + observed = _normal.Normal(mean_observed, cholesky_obs) + return observed, cond_util.Conditional(gain, corrected) + + +def _transpose(arr, /): + return np.transpose(arr, axes=(0, 2, 1)) diff --git a/probdiffeq/impl/_variable.py b/probdiffeq/impl/_variable.py index 61826dc6..0ea53b73 100644 --- a/probdiffeq/impl/_variable.py +++ b/probdiffeq/impl/_variable.py @@ -1,4 +1,6 @@ -from probdiffeq.backend import abc +from probdiffeq.backend import abc, functools +from probdiffeq.backend import numpy as np +from probdiffeq.impl import _normal class VariableBackend(abc.ABC): @@ -13,3 +15,72 @@ def rescale_cholesky(self, rv, factor, /): @abc.abstractmethod def transform_unit_sample(self, unit_sample, /, rv): raise NotImplementedError + + +class ScalarVariable(VariableBackend): + def rescale_cholesky(self, rv, factor): + if np.ndim(factor) > 0: + return functools.vmap(self.rescale_cholesky)(rv, factor) + return _normal.Normal(rv.mean, factor * rv.cholesky) + + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + rv.cholesky @ unit_sample + + def to_multivariate_normal(self, rv): + return rv.mean, rv.cholesky @ rv.cholesky.T + + +class DenseVariable(VariableBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def rescale_cholesky(self, rv, factor, /): + cholesky = factor[..., None, None] * rv.cholesky + return _normal.Normal(rv.mean, cholesky) + + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + rv.cholesky @ unit_sample + + def to_multivariate_normal(self, rv): + return rv.mean, rv.cholesky @ rv.cholesky.T + + +class IsotropicVariable(VariableBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def rescale_cholesky(self, rv, factor, /): + cholesky = factor[..., None, None] * rv.cholesky + return _normal.Normal(rv.mean, cholesky) + + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + rv.cholesky @ unit_sample + + def to_multivariate_normal(self, rv): + eye_d = np.eye(*self.ode_shape) + cov = rv.cholesky @ rv.cholesky.T + cov = np.kron(eye_d, cov) + mean = rv.mean.reshape((-1,), order="F") + return (mean, cov) + + +class BlockDiagVariable(VariableBackend): + def __init__(self, ode_shape): + self.ode_shape = ode_shape + + def rescale_cholesky(self, rv, factor, /): + cholesky = factor[..., None, None] * rv.cholesky + return _normal.Normal(rv.mean, cholesky) + + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + (rv.cholesky @ unit_sample[..., None])[..., 0] + + def to_multivariate_normal(self, rv): + mean = np.reshape(rv.mean.T, (-1,), order="F") + cov = np.block_diag(self._cov_dense(rv.cholesky)) + return (mean, cov) + + def _cov_dense(self, cholesky): + if cholesky.ndim > 2: + return functools.vmap(self._cov_dense)(cholesky) + return cholesky @ cholesky.T diff --git a/probdiffeq/impl/blockdiag/__init__.py b/probdiffeq/impl/blockdiag/__init__.py deleted file mode 100644 index 9151ca98..00000000 --- a/probdiffeq/impl/blockdiag/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Block-diagonal models.""" diff --git a/probdiffeq/impl/blockdiag/_conditional.py b/probdiffeq/impl/blockdiag/_conditional.py deleted file mode 100644 index 796ab426..00000000 --- a/probdiffeq/impl/blockdiag/_conditional.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Conditional implementation.""" - -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _conditional -from probdiffeq.impl.blockdiag import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class ConditionalBackend(_conditional.ConditionalBackend): - def apply(self, x, conditional, /): - if np.ndim(x) == 1: - x = x[..., None] - - def apply_unbatch(m, s, n): - return _normal.Normal(m @ s + n.mean, n.cholesky) - - matrix, noise = conditional - return functools.vmap(apply_unbatch)(matrix, x, noise) - - def marginalise(self, rv, conditional, /): - matrix, noise = conditional - assert matrix.ndim == 3 - - mean = np.einsum("ijk,ik->ij", matrix, rv.mean) + noise.mean - - chol1 = _transpose(matrix @ rv.cholesky) - chol2 = _transpose(noise.cholesky) - R_stack = (chol1, chol2) - cholesky = functools.vmap(cholesky_util.sum_of_sqrtm_factors)(R_stack) - return _normal.Normal(mean, _transpose(cholesky)) - - def merge(self, cond1, cond2, /): - A, b = cond1 - C, d = cond2 - - g = A @ C - xi = (A @ d.mean[..., None])[..., 0] + b.mean - R_stack = (_transpose(A @ d.cholesky), _transpose(b.cholesky)) - Xi = _transpose(functools.vmap(cholesky_util.sum_of_sqrtm_factors)(R_stack)) - - noise = _normal.Normal(xi, Xi) - return cond_util.Conditional(g, noise) - - def revert(self, rv, conditional, /): - A, noise = conditional - rv_chol_upper = np.transpose(rv.cholesky, axes=(0, 2, 1)) - noise_chol_upper = np.transpose(noise.cholesky, axes=(0, 2, 1)) - A_rv_chol_upper = np.transpose(A @ rv.cholesky, axes=(0, 2, 1)) - - revert = functools.vmap(cholesky_util.revert_conditional) - r_obs, (r_cor, gain) = revert(A_rv_chol_upper, rv_chol_upper, noise_chol_upper) - - cholesky_obs = np.transpose(r_obs, axes=(0, 2, 1)) - cholesky_cor = np.transpose(r_cor, axes=(0, 2, 1)) - - # Gather terms and return - mean_observed = (A @ rv.mean[..., None])[..., 0] + noise.mean - m_cor = rv.mean - (gain @ (mean_observed[..., None]))[..., 0] - corrected = _normal.Normal(m_cor, cholesky_cor) - observed = _normal.Normal(mean_observed, cholesky_obs) - return observed, cond_util.Conditional(gain, corrected) - - -def _transpose(matrix): - return np.transpose(matrix, axes=(0, 2, 1)) diff --git a/probdiffeq/impl/blockdiag/_hidden_model.py b/probdiffeq/impl/blockdiag/_hidden_model.py deleted file mode 100644 index 14d23aeb..00000000 --- a/probdiffeq/impl/blockdiag/_hidden_model.py +++ /dev/null @@ -1,42 +0,0 @@ -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _hidden_model -from probdiffeq.impl.blockdiag import _normal -from probdiffeq.util import cholesky_util, cond_util, linop_util - - -class HiddenModelBackend(_hidden_model.HiddenModelBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def qoi(self, rv): - return rv.mean[..., 0] - - def marginal_nth_derivative(self, rv, i): - if np.ndim(rv.mean) > 2: - return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( - rv, i - ) - - if i > np.shape(rv.mean)[0]: - raise ValueError - - mean = rv.mean[:, i] - cholesky = functools.vmap(cholesky_util.triu_via_qr)( - (rv.cholesky[:, i, :])[..., None] - ) - cholesky = np.transpose(cholesky, axes=(0, 2, 1)) - return _normal.Normal(mean, cholesky) - - def qoi_from_sample(self, sample, /): - return sample[..., 0] - - def conditional_to_derivative(self, i, standard_deviation): - def A(x): - return x[:, [i], ...] - - bias = np.zeros((*self.ode_shape, 1)) - eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...] - noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop(lambda s, _p: A(s)) - return cond_util.Conditional(linop, noise) diff --git a/probdiffeq/impl/blockdiag/_linearise.py b/probdiffeq/impl/blockdiag/_linearise.py deleted file mode 100644 index c6771672..00000000 --- a/probdiffeq/impl/blockdiag/_linearise.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Linearisation.""" - -from probdiffeq.impl import _linearise -from probdiffeq.util import linop_util - - -class LinearisationBackend(_linearise.LinearisationBackend): - def ode_taylor_0th(self, ode_order): - def linearise_fun_wrapped(fun, mean): - m0 = mean[:, :ode_order] - fx = ts0(fun, m0.T) - - def a1(s): - return s[:, [ode_order], ...] - - return linop_util.parametrised_linop(lambda v, _p: a1(v)), -fx[:, None] - - return linearise_fun_wrapped - - def ode_taylor_1st(self, ode_order): - raise NotImplementedError - - def ode_statistical_0th(self, cubature_fun): - raise NotImplementedError - - def ode_statistical_1st(self, cubature_fun): - raise NotImplementedError - - -def ts0(fn, m): - return fn(m) diff --git a/probdiffeq/impl/blockdiag/_normal.py b/probdiffeq/impl/blockdiag/_normal.py deleted file mode 100644 index dbb1120b..00000000 --- a/probdiffeq/impl/blockdiag/_normal.py +++ /dev/null @@ -1,24 +0,0 @@ -from probdiffeq.backend import tree_util - - -class Normal: - def __init__(self, mean, cholesky): - self.mean = mean - self.cholesky = cholesky - - def __repr__(self): - return f"Normal({self.mean}, cholesky={self.cholesky})" - - -def _flatten(normal): - children = (normal.mean, normal.cholesky) - aux = () - return children, aux - - -def _unflatten(_aux, children): - (mean, cholesky) = children - return Normal(mean, cholesky) - - -tree_util.register_pytree_node(Normal, _flatten, _unflatten) diff --git a/probdiffeq/impl/blockdiag/_prototypes.py b/probdiffeq/impl/blockdiag/_prototypes.py deleted file mode 100644 index cc218f8c..00000000 --- a/probdiffeq/impl/blockdiag/_prototypes.py +++ /dev/null @@ -1,22 +0,0 @@ -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _prototypes -from probdiffeq.impl.blockdiag import _normal - - -class PrototypeBackend(_prototypes.PrototypeBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def qoi(self): - return np.empty(self.ode_shape) - - def observed(self): - mean = np.empty((*self.ode_shape, 1)) - cholesky = np.empty((*self.ode_shape, 1, 1)) - return _normal.Normal(mean, cholesky) - - def error_estimate(self): - return np.empty(self.ode_shape) - - def output_scale(self): - return np.empty(self.ode_shape) diff --git a/probdiffeq/impl/blockdiag/_ssm_util.py b/probdiffeq/impl/blockdiag/_ssm_util.py deleted file mode 100644 index 490a2863..00000000 --- a/probdiffeq/impl/blockdiag/_ssm_util.py +++ /dev/null @@ -1,70 +0,0 @@ -"""State-space model utilities.""" - -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _ssm_util -from probdiffeq.impl.blockdiag import _normal -from probdiffeq.util import cholesky_util, cond_util, ibm_util - - -class SSMUtilBackend(_ssm_util.SSMUtilBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def ibm_transitions(self, num_derivatives, output_scale): - system_matrices = functools.vmap(ibm_util.system_matrices_1d, in_axes=(None, 0)) - a, q_sqrtm = system_matrices(num_derivatives, output_scale) - - q0 = np.zeros((*self.ode_shape, num_derivatives + 1)) - noise = _normal.Normal(q0, q_sqrtm) - - precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) - - def discretise(dt): - p, p_inv = precon_fun(dt) - return (a, noise), (p, p_inv) - - return discretise - - def identity_conditional(self, ndim, /): - m0 = np.zeros((*self.ode_shape, ndim)) - c0 = np.zeros((*self.ode_shape, ndim, ndim)) - noise = _normal.Normal(m0, c0) - - matrix = np.ones((*self.ode_shape, 1, 1)) * np.eye(ndim, ndim)[None, ...] - return cond_util.Conditional(matrix, noise) - - def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): - if len(tcoeffs) != num_derivatives + 1: - msg1 = "The number of Taylor coefficients does not match " - msg2 = "the number of derivatives in the implementation." - raise ValueError(msg1 + msg2) - - cholesky_shape = (*self.ode_shape, num_derivatives + 1, num_derivatives + 1) - cholesky = np.zeros(cholesky_shape) - mean = np.stack(tcoeffs).T - return _normal.Normal(mean, cholesky) - - def preconditioner_apply(self, rv, p, /): - mean = p[None, :] * rv.mean - cholesky = p[None, :, None] * rv.cholesky - return _normal.Normal(mean, cholesky) - - def preconditioner_apply_cond(self, cond, p, p_inv, /): - A, noise = cond - A_new = p[None, :, None] * A * p_inv[None, None, :] - noise = self.preconditioner_apply(noise, p) - return cond_util.Conditional(A_new, noise) - - def standard_normal(self, ndim, output_scale): - mean = np.zeros((*self.ode_shape, ndim)) - cholesky = output_scale[:, None, None] * np.eye(ndim)[None, ...] - return _normal.Normal(mean, cholesky) - - def update_mean(self, mean, x, /, num): - if np.ndim(mean) > 0: - assert np.shape(mean) == np.shape(x) - return functools.vmap(self.update_mean, in_axes=(0, 0, None))(mean, x, num) - - sum_updated = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) - return sum_updated / np.sqrt(num + 1) diff --git a/probdiffeq/impl/blockdiag/_stats.py b/probdiffeq/impl/blockdiag/_stats.py deleted file mode 100644 index 652569ef..00000000 --- a/probdiffeq/impl/blockdiag/_stats.py +++ /dev/null @@ -1,41 +0,0 @@ -from probdiffeq.backend import functools, linalg -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _stats - - -class StatsBackend(_stats.StatsBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def mahalanobis_norm_relative(self, u, /, rv): - # assumes rv.chol = (d,1,1) - # return array of norms! See calibration - mean = np.reshape(rv.mean, self.ode_shape) - cholesky = np.reshape(rv.cholesky, self.ode_shape) - return (mean - u) / cholesky / np.sqrt(mean.size) - - def logpdf(self, u, /, rv): - def logpdf_scalar(x, r): - dx = x - r.mean - w = linalg.solve_triangular(r.cholesky.T, dx, trans="T") - - maha_term = linalg.vector_dot(w, w) - - diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2) - slogdet = np.sum(np.log(np.abs(diagonal))) - logdet_term = 2.0 * slogdet - return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2)) - - return np.sum(functools.vmap(logpdf_scalar)(u, rv)) - - def mean(self, rv): - return rv.mean - - def sample_shape(self, rv): - return rv.mean.shape - - def standard_deviation(self, rv): - if rv.cholesky.ndim > 1: - return functools.vmap(self.standard_deviation)(rv) - - return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky)) diff --git a/probdiffeq/impl/blockdiag/_transform.py b/probdiffeq/impl/blockdiag/_transform.py deleted file mode 100644 index 3df78efa..00000000 --- a/probdiffeq/impl/blockdiag/_transform.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Random-variable transformation.""" - -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _transform -from probdiffeq.impl.blockdiag import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class TransformBackend(_transform.TransformBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def marginalise(self, rv, transformation, /): - A, b = transformation - mean, cholesky = rv.mean, rv.cholesky - - A_cholesky = A @ cholesky - cholesky = functools.vmap(cholesky_util.triu_via_qr)(_transpose(A_cholesky)) - mean = A @ mean + b - return _normal.Normal(mean, cholesky) - - def revert(self, rv, transformation, /): - A, bias = transformation - cholesky_upper = np.transpose(rv.cholesky, axes=(0, -1, -2)) - A_cholesky_upper = _transpose(A @ rv.cholesky) - - revert_fun = functools.vmap(cholesky_util.revert_conditional_noisefree) - r_obs, (r_cor, gain) = revert_fun(A_cholesky_upper, cholesky_upper) - cholesky_obs = _transpose(r_obs) - cholesky_cor = _transpose(r_cor) - - # Gather terms and return - mean_observed = (A @ rv.mean) + bias - m_cor = rv.mean - (gain * (mean_observed[..., None]))[..., 0] - corrected = _normal.Normal(m_cor, cholesky_cor) - observed = _normal.Normal(mean_observed, cholesky_obs) - return observed, cond_util.Conditional(gain, corrected) - - -def _transpose(arr, /): - return np.transpose(arr, axes=(0, 2, 1)) diff --git a/probdiffeq/impl/blockdiag/_variable.py b/probdiffeq/impl/blockdiag/_variable.py deleted file mode 100644 index 063df1c7..00000000 --- a/probdiffeq/impl/blockdiag/_variable.py +++ /dev/null @@ -1,26 +0,0 @@ -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _variable -from probdiffeq.impl.blockdiag import _normal - - -class VariableBackend(_variable.VariableBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def rescale_cholesky(self, rv, factor, /): - cholesky = factor[..., None, None] * rv.cholesky - return _normal.Normal(rv.mean, cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + (rv.cholesky @ unit_sample[..., None])[..., 0] - - def to_multivariate_normal(self, rv): - mean = np.reshape(rv.mean.T, (-1,), order="F") - cov = np.block_diag(self._cov_dense(rv.cholesky)) - return (mean, cov) - - def _cov_dense(self, cholesky): - if cholesky.ndim > 2: - return functools.vmap(self._cov_dense)(cholesky) - return cholesky @ cholesky.T diff --git a/probdiffeq/impl/blockdiag/factorised_impl.py b/probdiffeq/impl/blockdiag/factorised_impl.py deleted file mode 100644 index 3b303b87..00000000 --- a/probdiffeq/impl/blockdiag/factorised_impl.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Factorisations.""" - -from probdiffeq.impl import _impl -from probdiffeq.impl.blockdiag import ( - _conditional, - _hidden_model, - _linearise, - _prototypes, - _ssm_util, - _stats, - _transform, - _variable, -) - - -class BlockDiag(_impl.FactorisedImpl): - """Block-diagonal factorisation.""" - - def __init__(self, ode_shape): - """Construct a block-diagonal factorisation.""" - self.ode_shape = ode_shape - - def conditional(self): - return _conditional.ConditionalBackend() - - def linearise(self): - return _linearise.LinearisationBackend() - - def hidden_model(self): - return _hidden_model.HiddenModelBackend(ode_shape=self.ode_shape) - - def stats(self): - return _stats.StatsBackend(ode_shape=self.ode_shape) - - def variable(self): - return _variable.VariableBackend(ode_shape=self.ode_shape) - - def ssm_util(self): - return _ssm_util.SSMUtilBackend(ode_shape=self.ode_shape) - - def transform(self): - return _transform.TransformBackend(ode_shape=self.ode_shape) - - def prototypes(self): - return _prototypes.PrototypeBackend(ode_shape=self.ode_shape) diff --git a/probdiffeq/impl/dense/__init__.py b/probdiffeq/impl/dense/__init__.py deleted file mode 100644 index df4f75ef..00000000 --- a/probdiffeq/impl/dense/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Dense state-space models (no factorisation).""" diff --git a/probdiffeq/impl/dense/_conditional.py b/probdiffeq/impl/dense/_conditional.py deleted file mode 100644 index d8a1dae0..00000000 --- a/probdiffeq/impl/dense/_conditional.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Conditional implementation.""" - -from probdiffeq.impl import _conditional -from probdiffeq.impl.dense import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class ConditionalBackend(_conditional.ConditionalBackend): - def apply(self, x, conditional, /): - matrix, noise = conditional - return _normal.Normal(matrix @ x + noise.mean, noise.cholesky) - - def marginalise(self, rv, conditional, /): - matmul, noise = conditional - R_stack = ((matmul @ rv.cholesky).T, noise.cholesky.T) - cholesky_new = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T - return _normal.Normal(matmul @ rv.mean + noise.mean, cholesky_new) - - def merge(self, cond1, cond2, /): - A, b = cond1 - C, d = cond2 - - g = A @ C - xi = A @ d.mean + b.mean - Xi = cholesky_util.sum_of_sqrtm_factors( - R_stack=((A @ d.cholesky).T, b.cholesky.T) - ) - return cond_util.Conditional(g, _normal.Normal(xi, Xi.T)) - - def revert(self, rv, conditional, /): - matrix, noise = conditional - mean, cholesky = rv.mean, rv.cholesky - - # QR-decomposition - # (todo: rename revert_conditional_noisefree to - # revert_transformation_cov_sqrt()) - r_obs, (r_cor, gain) = cholesky_util.revert_conditional( - R_X_F=(matrix @ cholesky).T, R_X=cholesky.T, R_YX=noise.cholesky.T - ) - - # Gather terms and return - mean_observed = matrix @ mean + noise.mean - m_cor = mean - gain @ mean_observed - corrected = _normal.Normal(m_cor, r_cor.T) - observed = _normal.Normal(mean_observed, r_obs.T) - return observed, cond_util.Conditional(gain, corrected) diff --git a/probdiffeq/impl/dense/_hidden_model.py b/probdiffeq/impl/dense/_hidden_model.py deleted file mode 100644 index 92536993..00000000 --- a/probdiffeq/impl/dense/_hidden_model.py +++ /dev/null @@ -1,57 +0,0 @@ -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _hidden_model -from probdiffeq.impl.dense import _normal -from probdiffeq.util import cholesky_util, cond_util, linop_util - - -class HiddenModelBackend(_hidden_model.HiddenModelBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def qoi(self, rv): - if np.ndim(rv.mean) > 1: - return functools.vmap(self.qoi)(rv) - mean_reshaped = np.reshape(rv.mean, (-1, *self.ode_shape), order="F") - return mean_reshaped[0] - - def marginal_nth_derivative(self, rv, i): - if rv.mean.ndim > 1: - return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( - rv, i - ) - - m = self._select(rv.mean, i) - c = functools.vmap(self._select, in_axes=(1, None), out_axes=1)(rv.cholesky, i) - c = cholesky_util.triu_via_qr(c.T) - return _normal.Normal(m, c.T) - - def qoi_from_sample(self, sample, /): - sample_reshaped = np.reshape(sample, (-1, *self.ode_shape), order="F") - return sample_reshaped[0] - - # TODO: move to linearise.py? - def conditional_to_derivative(self, i, standard_deviation): - a0 = functools.partial(self._select, idx_or_slice=i) - - (d,) = self.ode_shape - bias = np.zeros((d,)) - eye = np.eye(d) - noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop(lambda s, _p: _autobatch_linop(a0)(s)) - return cond_util.Conditional(linop, noise) - - def _select(self, x, /, idx_or_slice): - x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F") - if isinstance(idx_or_slice, int) and idx_or_slice > x_reshaped.shape[0]: - raise ValueError - return x_reshaped[idx_or_slice] - - -def _autobatch_linop(fun): - def fun_(x): - if np.ndim(x) > 1: - return functools.vmap(fun_, in_axes=1, out_axes=1)(x) - return fun(x) - - return fun_ diff --git a/probdiffeq/impl/dense/_linearise.py b/probdiffeq/impl/dense/_linearise.py deleted file mode 100644 index d16495f8..00000000 --- a/probdiffeq/impl/dense/_linearise.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Linearisation.""" - -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _linearise -from probdiffeq.impl.dense import _normal -from probdiffeq.util import cholesky_util, linop_util - - -class LinearisationBackend(_linearise.LinearisationBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def ode_taylor_0th(self, ode_order): - def linearise_fun_wrapped(fun, mean): - a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order)) - a1 = functools.partial(self._select_dy, idx_or_slice=ode_order) - - if np.shape(a0(mean)) != (expected_shape := (ode_order, *self.ode_shape)): - msg = f"{np.shape(a0(mean))} != {expected_shape}" - raise ValueError(msg) - - fx = ts0(fun, a0(mean)) - linop = linop_util.parametrised_linop(lambda v, _p: _autobatch_linop(a1)(v)) - return linop, -fx - - return linearise_fun_wrapped - - def ode_taylor_1st(self, ode_order): - def new(fun, mean, /): - a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order)) - a1 = functools.partial(self._select_dy, idx_or_slice=ode_order) - - if np.shape(a0(mean)) != (expected_shape := (ode_order, *self.ode_shape)): - msg = f"{np.shape(a0(mean))} != {expected_shape}" - raise ValueError(msg) - - jvp, fx = ts1(fun, a0(mean)) - - @_autobatch_linop - def A(x): - x1 = a1(x) - x0 = a0(x) - return x1 - jvp(x0) - - linop = linop_util.parametrised_linop(lambda v, _p: A(v)) - return linop, -fx - - return new - - def ode_statistical_1st(self, cubature_fun): - cubature_rule = cubature_fun(input_shape=self.ode_shape) - linearise_fun = functools.partial(slr1, cubature_rule=cubature_rule) - - def new(fun, rv, /): - # Projection functions - a0 = _autobatch_linop(functools.partial(self._select_dy, idx_or_slice=0)) - a1 = _autobatch_linop(functools.partial(self._select_dy, idx_or_slice=1)) - - # Extract the linearisation point - m0, r_0_nonsquare = a0(rv.mean), a0(rv.cholesky) - r_0_square = cholesky_util.triu_via_qr(r_0_nonsquare.T) - linearisation_pt = _normal.Normal(m0, r_0_square.T) - - # Gather the variables and return - J, noise = linearise_fun(fun, linearisation_pt) - - def A(x): - return a1(x) - J @ a0(x) - - linop = linop_util.parametrised_linop(lambda v, _p: A(v)) - - mean, cov_lower = noise.mean, noise.cholesky - bias = _normal.Normal(-mean, cov_lower) - return linop, bias - - return new - - def ode_statistical_0th(self, cubature_fun): - cubature_rule = cubature_fun(input_shape=self.ode_shape) - linearise_fun = functools.partial(slr0, cubature_rule=cubature_rule) - - def new(fun, rv, /): - # Projection functions - a0 = _autobatch_linop(functools.partial(self._select_dy, idx_or_slice=0)) - a1 = _autobatch_linop(functools.partial(self._select_dy, idx_or_slice=1)) - - # Extract the linearisation point - m0, r_0_nonsquare = a0(rv.mean), a0(rv.cholesky) - r_0_square = cholesky_util.triu_via_qr(r_0_nonsquare.T) - linearisation_pt = _normal.Normal(m0, r_0_square.T) - - # Gather the variables and return - noise = linearise_fun(fun, linearisation_pt) - mean, cov_lower = noise.mean, noise.cholesky - bias = _normal.Normal(-mean, cov_lower) - linop = linop_util.parametrised_linop(lambda v, _p: a1(v)) - return linop, bias - - return new - - def _select_dy(self, x, idx_or_slice): - (d,) = self.ode_shape - x_reshaped = np.reshape(x, (-1, d), order="F") - return x_reshaped[idx_or_slice, ...] - - -def _autobatch_linop(fun): - def fun_(x): - if np.ndim(x) > 1: - return functools.vmap(fun_, in_axes=1, out_axes=1)(x) - return fun(x) - - return fun_ - - -def ts0(fn, m): - return fn(m) - - -def ts1(fn, m): - b, jvp = functools.linearize(fn, m) - return jvp, b - jvp(m) - - -def slr1(fn, x, *, cubature_rule): - """Linearise a function with first-order statistical linear regression.""" - # Create sigma-points - pts_centered = cubature_rule.points @ x.cholesky.T - pts = x.mean[None, :] + pts_centered - pts_centered_normed = pts_centered * cubature_rule.weights_sqrtm[:, None] - - # Evaluate the nonlinear function - fx = functools.vmap(fn)(pts) - fx_mean = cubature_rule.weights_sqrtm**2 @ fx - fx_centered = fx - fx_mean[None, :] - fx_centered_normed = fx_centered * cubature_rule.weights_sqrtm[:, None] - - # Compute statistical linear regression matrices - _, (cov_sqrtm_cond, linop_cond) = cholesky_util.revert_conditional_noisefree( - R_X_F=pts_centered_normed, R_X=fx_centered_normed - ) - mean_cond = fx_mean - linop_cond @ x.mean - rv_cond = _normal.Normal(mean_cond, cov_sqrtm_cond.T) - return linop_cond, rv_cond - - -def slr0(fn, x, *, cubature_rule): - """Linearise a function with zeroth-order statistical linear regression. - - !!! warning "Warning: highly EXPERIMENTAL feature!" - This feature is highly experimental. - There is no guarantee that it works correctly. - It might be deleted tomorrow - and without any deprecation policy. - - """ - # Create sigma-points - pts_centered = cubature_rule.points @ x.cholesky.T - pts = x.mean[None, :] + pts_centered - - # Evaluate the nonlinear function - fx = functools.vmap(fn)(pts) - fx_mean = cubature_rule.weights_sqrtm**2 @ fx - fx_centered = fx - fx_mean[None, :] - fx_centered_normed = fx_centered * cubature_rule.weights_sqrtm[:, None] - - cov_sqrtm = cholesky_util.triu_via_qr(fx_centered_normed) - - return _normal.Normal(fx_mean, cov_sqrtm.T) diff --git a/probdiffeq/impl/dense/_prototypes.py b/probdiffeq/impl/dense/_prototypes.py deleted file mode 100644 index f1361553..00000000 --- a/probdiffeq/impl/dense/_prototypes.py +++ /dev/null @@ -1,22 +0,0 @@ -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _prototypes -from probdiffeq.impl.dense import _normal - - -class PrototypeBackend(_prototypes.PrototypeBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def qoi(self): - return np.empty(self.ode_shape) - - def observed(self): - mean = np.empty(self.ode_shape) - cholesky = np.empty(self.ode_shape + self.ode_shape) - return _normal.Normal(mean, cholesky) - - def error_estimate(self): - return np.empty(self.ode_shape) - - def output_scale(self): - return np.empty(()) diff --git a/probdiffeq/impl/dense/_ssm_util.py b/probdiffeq/impl/dense/_ssm_util.py deleted file mode 100644 index fcccc941..00000000 --- a/probdiffeq/impl/dense/_ssm_util.py +++ /dev/null @@ -1,83 +0,0 @@ -"""State-space model utilities.""" - -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _ssm_util -from probdiffeq.impl.dense import _normal -from probdiffeq.util import cholesky_util, cond_util, ibm_util - - -class SSMUtilBackend(_ssm_util.SSMUtilBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def ibm_transitions(self, num_derivatives, output_scale): - a, q_sqrtm = ibm_util.system_matrices_1d(num_derivatives, output_scale) - (d,) = self.ode_shape - eye_d = np.eye(d) - A = np.kron(eye_d, a) - Q = np.kron(eye_d, q_sqrtm) - - ndim = d * (num_derivatives + 1) - q0 = np.zeros((ndim,)) - noise = _normal.Normal(q0, Q) - - precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) - - def discretise(dt): - p, p_inv = precon_fun(dt) - p = np.tile(p, d) - p_inv = np.tile(p_inv, d) - return cond_util.Conditional(A, noise), (p, p_inv) - - return discretise - - def identity_conditional(self, ndim, /): - (d,) = self.ode_shape - n = ndim * d - - A = np.eye(n) - m = np.zeros((n,)) - C = np.zeros((n, n)) - return cond_util.Conditional(A, _normal.Normal(m, C)) - - def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): - if len(tcoeffs) != num_derivatives + 1: - msg1 = f"The number of Taylor coefficients {len(tcoeffs)} does not match " - msg2 = f"the number of derivatives {num_derivatives+1} in the solver." - raise ValueError(msg1 + msg2) - - if tcoeffs[0].shape != self.ode_shape: - msg = "The solver's ODE dimension does not match the initial condition." - raise ValueError(msg) - - m0_matrix = np.stack(tcoeffs) - m0_corrected = np.reshape(m0_matrix, (-1,), order="F") - - (ode_dim,) = self.ode_shape - ndim = (num_derivatives + 1) * ode_dim - c_sqrtm0_corrected = np.zeros((ndim, ndim)) - - return _normal.Normal(m0_corrected, c_sqrtm0_corrected) - - def preconditioner_apply(self, rv, p, /): - mean = p * rv.mean - cholesky = p[:, None] * rv.cholesky - return _normal.Normal(mean, cholesky) - - def preconditioner_apply_cond(self, cond, p, p_inv, /): - A, noise = cond - noise = self.preconditioner_apply(noise, p) - A = p[:, None] * A * p_inv[None, :] - return cond_util.Conditional(A, noise) - - def standard_normal(self, ndim, /, output_scale): - eye_n = np.eye(ndim) - eye_d = output_scale * np.eye(*self.ode_shape) - cholesky = np.kron(eye_d, eye_n) - mean = np.zeros((*self.ode_shape, ndim)).reshape((-1,), order="F") - return _normal.Normal(mean, cholesky) - - def update_mean(self, mean, x, /, num): - return cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) / np.sqrt( - num + 1 - ) diff --git a/probdiffeq/impl/dense/_stats.py b/probdiffeq/impl/dense/_stats.py deleted file mode 100644 index ec825d2f..00000000 --- a/probdiffeq/impl/dense/_stats.py +++ /dev/null @@ -1,40 +0,0 @@ -from probdiffeq.backend import functools, linalg -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _stats - - -class StatsBackend(_stats.StatsBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def mahalanobis_norm_relative(self, u, /, rv): - residual_white = linalg.solve_triangular( - rv.cholesky.T, u - rv.mean, lower=False, trans="T" - ) - mahalanobis = linalg.qr_r(residual_white[:, None]) - return np.reshape(np.abs(mahalanobis) / np.sqrt(rv.mean.size), ()) - - def logpdf(self, u, /, rv): - # The cholesky factor is triangular, so we compute a cheap slogdet. - diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2) - slogdet = np.sum(np.log(np.abs(diagonal))) - - dx = u - rv.mean - residual_white = linalg.solve_triangular(rv.cholesky.T, dx, trans="T") - x1 = linalg.vector_dot(residual_white, residual_white) - x2 = 2.0 * slogdet - x3 = u.size * np.log(np.pi() * 2) - return -0.5 * (x1 + x2 + x3) - - def mean(self, rv): - return rv.mean - - def standard_deviation(self, rv): - if rv.mean.ndim > 1: - return functools.vmap(self.standard_deviation)(rv) - - diag = np.einsum("ij,ij->i", rv.cholesky, rv.cholesky) - return np.sqrt(diag) - - def sample_shape(self, rv): - return rv.mean.shape diff --git a/probdiffeq/impl/dense/_transform.py b/probdiffeq/impl/dense/_transform.py deleted file mode 100644 index 4d248352..00000000 --- a/probdiffeq/impl/dense/_transform.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Random variable transformations.""" - -from probdiffeq.backend import containers -from probdiffeq.backend.typing import Array, Callable -from probdiffeq.impl import _transform -from probdiffeq.impl.dense import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class Transformation(containers.NamedTuple): - matmul: Callable - bias: Array - - -class TransformBackend(_transform.TransformBackend): - def marginalise(self, rv, transformation, /): - A, b = transformation - cholesky_new = cholesky_util.triu_via_qr((A @ rv.cholesky).T).T - return _normal.Normal(A @ rv.mean + b, cholesky_new) - - def revert(self, rv, transformation, /): - A, b = transformation - mean, cholesky = rv.mean, rv.cholesky - - # QR-decomposition - # (todo: rename revert_conditional_noisefree to - # revert_transformation_cov_sqrt()) - r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree( - R_X_F=(A @ cholesky).T, R_X=cholesky.T - ) - - # Gather terms and return - m_cor = mean - gain @ (A @ mean + b) - corrected = _normal.Normal(m_cor, r_cor.T) - observed = _normal.Normal(A @ mean + b, r_obs.T) - return observed, cond_util.Conditional(gain, corrected) diff --git a/probdiffeq/impl/dense/_variable.py b/probdiffeq/impl/dense/_variable.py deleted file mode 100644 index 82ddf57c..00000000 --- a/probdiffeq/impl/dense/_variable.py +++ /dev/null @@ -1,17 +0,0 @@ -from probdiffeq.impl import _variable -from probdiffeq.impl.dense import _normal - - -class VariableBackend(_variable.VariableBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def rescale_cholesky(self, rv, factor, /): - cholesky = factor[..., None, None] * rv.cholesky - return _normal.Normal(rv.mean, cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + rv.cholesky @ unit_sample - - def to_multivariate_normal(self, rv): - return rv.mean, rv.cholesky @ rv.cholesky.T diff --git a/probdiffeq/impl/dense/factorised_impl.py b/probdiffeq/impl/dense/factorised_impl.py deleted file mode 100644 index 7780f274..00000000 --- a/probdiffeq/impl/dense/factorised_impl.py +++ /dev/null @@ -1,46 +0,0 @@ -"""API for dense factorisations.""" - -from probdiffeq.impl import _impl -from probdiffeq.impl.dense import ( - _conditional, - _hidden_model, - _linearise, - _prototypes, - _ssm_util, - _stats, - _transform, - _variable, -) - - -class Dense(_impl.FactorisedImpl): - """Dense factorisation.""" - - def __init__(self, ode_shape): - """Construct a dense factorisation.""" - # TODO: add "order="F"" key - self.ode_shape = ode_shape - - def linearise(self): - return _linearise.LinearisationBackend(ode_shape=self.ode_shape) - - def hidden_model(self): - return _hidden_model.HiddenModelBackend(ode_shape=self.ode_shape) - - def stats(self): - return _stats.StatsBackend(ode_shape=self.ode_shape) - - def variable(self): - return _variable.VariableBackend(ode_shape=self.ode_shape) - - def conditional(self): - return _conditional.ConditionalBackend() - - def transform(self): - return _transform.TransformBackend() - - def ssm_util(self): - return _ssm_util.SSMUtilBackend(ode_shape=self.ode_shape) - - def prototypes(self): - return _prototypes.PrototypeBackend(ode_shape=self.ode_shape) diff --git a/probdiffeq/impl/isotropic/__init__.py b/probdiffeq/impl/isotropic/__init__.py deleted file mode 100644 index db9a1b6b..00000000 --- a/probdiffeq/impl/isotropic/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Isotropic state-space models.""" diff --git a/probdiffeq/impl/isotropic/_conditional.py b/probdiffeq/impl/isotropic/_conditional.py deleted file mode 100644 index 8603b0a6..00000000 --- a/probdiffeq/impl/isotropic/_conditional.py +++ /dev/null @@ -1,51 +0,0 @@ -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _conditional -from probdiffeq.impl.isotropic import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class ConditionalBackend(_conditional.ConditionalBackend): - def apply(self, x, conditional, /): - A, noise = conditional - # if the gain is qoi-to-hidden, the data is a (d,) array. - # this is problematic for the isotropic model unless we explicitly broadcast. - if np.ndim(x) == 1: - x = x[None, :] - return _normal.Normal(A @ x + noise.mean, noise.cholesky) - - def marginalise(self, rv, conditional, /): - matrix, noise = conditional - - mean = matrix @ rv.mean + noise.mean - - R_stack = ((matrix @ rv.cholesky).T, noise.cholesky.T) - cholesky = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T - return _normal.Normal(mean, cholesky) - - def merge(self, cond1, cond2, /): - A, b = cond1 - C, d = cond2 - - g = A @ C - xi = A @ d.mean + b.mean - R_stack = ((A @ d.cholesky).T, b.cholesky.T) - Xi = cholesky_util.sum_of_sqrtm_factors(R_stack).T - - noise = _normal.Normal(xi, Xi) - return cond_util.Conditional(g, noise) - - def revert(self, rv, conditional, /): - matrix, noise = conditional - - r_ext_p, (r_bw_p, gain) = cholesky_util.revert_conditional( - R_X_F=(matrix @ rv.cholesky).T, R_X=rv.cholesky.T, R_YX=noise.cholesky.T - ) - extrapolated_cholesky = r_ext_p.T - corrected_cholesky = r_bw_p.T - - extrapolated_mean = matrix @ rv.mean + noise.mean - corrected_mean = rv.mean - gain @ extrapolated_mean - - extrapolated = _normal.Normal(extrapolated_mean, extrapolated_cholesky) - corrected = _normal.Normal(corrected_mean, corrected_cholesky) - return extrapolated, cond_util.Conditional(gain, corrected) diff --git a/probdiffeq/impl/isotropic/_hidden_model.py b/probdiffeq/impl/isotropic/_hidden_model.py deleted file mode 100644 index 9ff880a6..00000000 --- a/probdiffeq/impl/isotropic/_hidden_model.py +++ /dev/null @@ -1,39 +0,0 @@ -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _hidden_model -from probdiffeq.impl.isotropic import _normal -from probdiffeq.util import cholesky_util, cond_util, linop_util - - -class HiddenModelBackend(_hidden_model.HiddenModelBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def qoi(self, rv): - return rv.mean[..., 0, :] - - def marginal_nth_derivative(self, rv, i): - if np.ndim(rv.mean) > 2: - return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( - rv, i - ) - - if i > np.shape(rv.mean)[0]: - raise ValueError - - mean = rv.mean[i, :] - cholesky = cholesky_util.triu_via_qr(rv.cholesky[[i], :].T).T - return _normal.Normal(mean, cholesky) - - def qoi_from_sample(self, sample, /): - return sample[0, :] - - def conditional_to_derivative(self, i, standard_deviation): - def A(x): - return x[[i], ...] - - bias = np.zeros(self.ode_shape) - eye = np.eye(1) - noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop(lambda s, _p: A(s)) - return cond_util.Conditional(linop, noise) diff --git a/probdiffeq/impl/isotropic/_linearise.py b/probdiffeq/impl/isotropic/_linearise.py deleted file mode 100644 index 4200cb54..00000000 --- a/probdiffeq/impl/isotropic/_linearise.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Linearisation.""" - -from probdiffeq.impl import _linearise -from probdiffeq.util import linop_util - - -class LinearisationBackend(_linearise.LinearisationBackend): - def ode_taylor_1st(self, ode_order): - raise NotImplementedError - - def ode_taylor_0th(self, ode_order): - def linearise_fun_wrapped(fun, mean): - fx = ts0(fun, mean[:ode_order, ...]) - linop = linop_util.parametrised_linop(lambda s, _p: s[[ode_order], ...]) - return linop, -fx - - return linearise_fun_wrapped - - def ode_statistical_0th(self, cubature_fun): - raise NotImplementedError - - def ode_statistical_1st(self, cubature_fun): - raise NotImplementedError - - -def ts0(fn, m): - return fn(m) diff --git a/probdiffeq/impl/isotropic/_normal.py b/probdiffeq/impl/isotropic/_normal.py deleted file mode 100644 index 555e2541..00000000 --- a/probdiffeq/impl/isotropic/_normal.py +++ /dev/null @@ -1,7 +0,0 @@ -from probdiffeq.backend import containers -from probdiffeq.backend.typing import Array - - -class Normal(containers.NamedTuple): - mean: Array - cholesky: Array diff --git a/probdiffeq/impl/isotropic/_prototypes.py b/probdiffeq/impl/isotropic/_prototypes.py deleted file mode 100644 index 09f3bd98..00000000 --- a/probdiffeq/impl/isotropic/_prototypes.py +++ /dev/null @@ -1,22 +0,0 @@ -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _prototypes -from probdiffeq.impl.isotropic import _normal - - -class PrototypeBackend(_prototypes.PrototypeBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def qoi(self): - return np.empty(self.ode_shape) - - def observed(self): - mean = np.empty((1, *self.ode_shape)) - cholesky = np.empty(()) - return _normal.Normal(mean, cholesky) - - def error_estimate(self): - return np.empty(()) - - def output_scale(self): - return np.empty(()) diff --git a/probdiffeq/impl/isotropic/_ssm_util.py b/probdiffeq/impl/isotropic/_ssm_util.py deleted file mode 100644 index 54258c7a..00000000 --- a/probdiffeq/impl/isotropic/_ssm_util.py +++ /dev/null @@ -1,60 +0,0 @@ -"""State-space model utilities.""" - -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _ssm_util -from probdiffeq.impl.isotropic import _normal -from probdiffeq.util import cholesky_util, cond_util, ibm_util - - -class SSMUtilBackend(_ssm_util.SSMUtilBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def ibm_transitions(self, num_derivatives, output_scale): - A, q_sqrtm = ibm_util.system_matrices_1d(num_derivatives, output_scale) - q0 = np.zeros((num_derivatives + 1, *self.ode_shape)) - noise = _normal.Normal(q0, q_sqrtm) - precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) - - def discretise(dt): - p, p_inv = precon_fun(dt) - return cond_util.Conditional(A, noise), (p, p_inv) - - return discretise - - def identity_conditional(self, num_hidden_states_per_ode_dim, /): - m0 = np.zeros((num_hidden_states_per_ode_dim, *self.ode_shape)) - c0 = np.zeros((num_hidden_states_per_ode_dim, num_hidden_states_per_ode_dim)) - noise = _normal.Normal(m0, c0) - matrix = np.eye(num_hidden_states_per_ode_dim) - return cond_util.Conditional(matrix, noise) - - def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): - if len(tcoeffs) != num_derivatives + 1: - msg1 = f"The number of Taylor coefficients {len(tcoeffs)} does not match " - msg2 = f"the number of derivatives {num_derivatives+1} in the solver." - raise ValueError(msg1 + msg2) - - c_sqrtm0_corrected = np.zeros((num_derivatives + 1, num_derivatives + 1)) - m0_corrected = np.stack(tcoeffs) - return _normal.Normal(m0_corrected, c_sqrtm0_corrected) - - def preconditioner_apply(self, rv, p, /): - return _normal.Normal(p[:, None] * rv.mean, p[:, None] * rv.cholesky) - - def preconditioner_apply_cond(self, cond, p, p_inv, /): - A, noise = cond - - A_new = p[:, None] * A * p_inv[None, :] - - noise = _normal.Normal(p[:, None] * noise.mean, p[:, None] * noise.cholesky) - return cond_util.Conditional(A_new, noise) - - def standard_normal(self, num, /, output_scale): - mean = np.zeros((num, *self.ode_shape)) - cholesky = output_scale * np.eye(num) - return _normal.Normal(mean, cholesky) - - def update_mean(self, mean, x, /, num): - sum_updated = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) - return sum_updated / np.sqrt(num + 1) diff --git a/probdiffeq/impl/isotropic/_stats.py b/probdiffeq/impl/isotropic/_stats.py deleted file mode 100644 index 9194aab1..00000000 --- a/probdiffeq/impl/isotropic/_stats.py +++ /dev/null @@ -1,46 +0,0 @@ -from probdiffeq.backend import functools, linalg -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _stats -from probdiffeq.impl.isotropic import _normal - - -class StatsBackend(_stats.StatsBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def mahalanobis_norm_relative(self, u, /, rv): - residual_white = (rv.mean - u) / rv.cholesky - residual_white_matrix = linalg.qr_r(residual_white.T) - return np.reshape(np.abs(residual_white_matrix) / np.sqrt(rv.mean.size), ()) - - def logpdf(self, u, /, rv): - # if the gain is qoi-to-hidden, the data is a (d,) array. - # this is problematic for the isotropic model unless we explicitly broadcast. - if np.ndim(u) == 1: - u = u[None, :] - - def logpdf_scalar(x, r): - dx = x - r.mean - w = linalg.solve_triangular(r.cholesky.T, dx, trans="T") - - maha_term = linalg.vector_dot(w, w) - - diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2) - slogdet = np.sum(np.log(np.abs(diagonal))) - logdet_term = 2.0 * slogdet - return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2)) - - # Batch in the "mean" dimension and sum the results. - rv_batch = _normal.Normal(1, None) - return np.sum(functools.vmap(logpdf_scalar, in_axes=(1, rv_batch))(u, rv)) - - def mean(self, rv): - return rv.mean - - def standard_deviation(self, rv): - if rv.cholesky.ndim > 1: - return functools.vmap(self.standard_deviation)(rv) - return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky)) - - def sample_shape(self, rv): - return rv.mean.shape diff --git a/probdiffeq/impl/isotropic/_transform.py b/probdiffeq/impl/isotropic/_transform.py deleted file mode 100644 index 29a54a8a..00000000 --- a/probdiffeq/impl/isotropic/_transform.py +++ /dev/null @@ -1,33 +0,0 @@ -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _transform -from probdiffeq.impl.isotropic import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class TransformBackend(_transform.TransformBackend): - def marginalise(self, rv, transformation, /): - A, b = transformation - mean, cholesky = rv.mean, rv.cholesky - cholesky_new = cholesky_util.triu_via_qr((A @ cholesky).T) - cholesky_squeezed = np.reshape(cholesky_new, ()) - return _normal.Normal((A @ mean) + b, cholesky_squeezed) - - def revert(self, rv, transformation, /): - A, b = transformation - mean, cholesky = rv.mean, rv.cholesky - - # QR-decomposition - # (todo: rename revert_conditional_noisefree - # to revert_transformation_cov_sqrt()) - r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree( - R_X_F=(A @ cholesky).T, R_X=cholesky.T - ) - cholesky_obs = np.reshape(r_obs, ()) - cholesky_cor = r_cor.T - - # Gather terms and return - mean_observed = A @ mean + b - m_cor = mean - gain * mean_observed - corrected = _normal.Normal(m_cor, cholesky_cor) - observed = _normal.Normal(mean_observed, cholesky_obs) - return observed, cond_util.Conditional(gain, corrected) diff --git a/probdiffeq/impl/isotropic/_variable.py b/probdiffeq/impl/isotropic/_variable.py deleted file mode 100644 index 67ab0f27..00000000 --- a/probdiffeq/impl/isotropic/_variable.py +++ /dev/null @@ -1,22 +0,0 @@ -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _variable -from probdiffeq.impl.isotropic import _normal - - -class VariableBackend(_variable.VariableBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def rescale_cholesky(self, rv, factor, /): - cholesky = factor[..., None, None] * rv.cholesky - return _normal.Normal(rv.mean, cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + rv.cholesky @ unit_sample - - def to_multivariate_normal(self, rv): - eye_d = np.eye(*self.ode_shape) - cov = rv.cholesky @ rv.cholesky.T - cov = np.kron(eye_d, cov) - mean = rv.mean.reshape((-1,), order="F") - return (mean, cov) diff --git a/probdiffeq/impl/isotropic/factorised_impl.py b/probdiffeq/impl/isotropic/factorised_impl.py deleted file mode 100644 index 9366629f..00000000 --- a/probdiffeq/impl/isotropic/factorised_impl.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Isotropic factorisation.""" - -from probdiffeq.impl import _impl -from probdiffeq.impl.isotropic import ( - _conditional, - _hidden_model, - _linearise, - _prototypes, - _ssm_util, - _stats, - _transform, - _variable, -) - - -class Isotropic(_impl.FactorisedImpl): - """Isotropic factorisation.""" - - def __init__(self, ode_shape): - """Construct an isotropic factorisation.""" - self.ode_shape = ode_shape - - def conditional(self): - return _conditional.ConditionalBackend() - - def linearise(self): - return _linearise.LinearisationBackend() - - def hidden_model(self): - return _hidden_model.HiddenModelBackend(ode_shape=self.ode_shape) - - def stats(self): - return _stats.StatsBackend(ode_shape=self.ode_shape) - - def variable(self) -> _variable.VariableBackend: - return _variable.VariableBackend(ode_shape=self.ode_shape) - - def ssm_util(self): - return _ssm_util.SSMUtilBackend(ode_shape=self.ode_shape) - - def transform(self): - return _transform.TransformBackend() - - def prototypes(self): - return _prototypes.PrototypeBackend(ode_shape=self.ode_shape) diff --git a/probdiffeq/impl/scalar/__init__.py b/probdiffeq/impl/scalar/__init__.py deleted file mode 100644 index baa7adcb..00000000 --- a/probdiffeq/impl/scalar/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Scalar factorisation.""" diff --git a/probdiffeq/impl/scalar/_conditional.py b/probdiffeq/impl/scalar/_conditional.py deleted file mode 100644 index e659f96f..00000000 --- a/probdiffeq/impl/scalar/_conditional.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Conditionals.""" - -from probdiffeq.backend import linalg -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _conditional -from probdiffeq.impl.scalar import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class ConditionalBackend(_conditional.ConditionalBackend): - def marginalise(self, rv, conditional, /): - matrix, noise = conditional - - mean = matrix @ rv.mean + noise.mean - R_stack = ((matrix @ rv.cholesky).T, noise.cholesky.T) - cholesky_T = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack) - return _normal.Normal(mean, cholesky_T.T) - - def revert(self, rv, conditional, /): - matrix, noise = conditional - - r_ext, (r_bw_p, g_bw_p) = cholesky_util.revert_conditional( - R_X_F=(matrix @ rv.cholesky).T, R_X=rv.cholesky.T, R_YX=noise.cholesky.T - ) - m_ext = matrix @ rv.mean + noise.mean - m_cond = rv.mean - g_bw_p @ m_ext - - marginal = _normal.Normal(m_ext, r_ext.T) - noise = _normal.Normal(m_cond, r_bw_p.T) - return marginal, cond_util.Conditional(g_bw_p, noise) - - def apply(self, x, conditional, /): - matrix, noise = conditional - matrix = np.squeeze(matrix) - return _normal.Normal(linalg.vector_dot(matrix, x) + noise.mean, noise.cholesky) - - def merge(self, previous, incoming, /): - A, b = previous - C, d = incoming - - g = A @ C - xi = A @ d.mean + b.mean - R_stack = ((A @ d.cholesky).T, b.cholesky.T) - Xi = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T - - noise = _normal.Normal(xi, Xi) - return cond_util.Conditional(g, noise) diff --git a/probdiffeq/impl/scalar/_hidden_model.py b/probdiffeq/impl/scalar/_hidden_model.py deleted file mode 100644 index 3a1c4c0b..00000000 --- a/probdiffeq/impl/scalar/_hidden_model.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Hidden state-space model implementation.""" - -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _hidden_model -from probdiffeq.impl.scalar import _normal -from probdiffeq.util import cholesky_util, cond_util, linop_util - - -class HiddenModelBackend(_hidden_model.HiddenModelBackend): - def qoi(self, rv): - return rv.mean[..., 0] - - def marginal_nth_derivative(self, rv, i): - if rv.mean.ndim > 1: - return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( - rv, i - ) - - if i > rv.mean.shape[0]: - raise ValueError - - m = rv.mean[i] - c = rv.cholesky[[i], :] - chol = cholesky_util.triu_via_qr(c.T) - return _normal.Normal(np.reshape(m, ()), np.reshape(chol, ())) - - def qoi_from_sample(self, sample, /): - return sample[0] - - def conditional_to_derivative(self, i, standard_deviation): - def A(x): - return x[[i], ...] - - bias = np.zeros(()) - eye = np.eye(1) - noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop(lambda s, _p: A(s)) - return cond_util.Conditional(linop, noise) diff --git a/probdiffeq/impl/scalar/_linearise.py b/probdiffeq/impl/scalar/_linearise.py deleted file mode 100644 index 25226263..00000000 --- a/probdiffeq/impl/scalar/_linearise.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Linearisation.""" - -from probdiffeq.impl import _linearise - - -class LinearisationBackend(_linearise.LinearisationBackend): - def ode_taylor_0th(self, ode_order): - def linearise_fun_wrapped(fun, mean): - fx = ts0(fun, mean[:ode_order]) - return lambda s: s[ode_order], -fx - - return linearise_fun_wrapped - - def ode_taylor_1st(self, ode_order): - raise NotImplementedError - - def ode_statistical_1st(self, cubature_fun): - raise NotImplementedError - - def ode_statistical_0th(self, cubature_fun): - raise NotImplementedError - - -def ts0(fn, m): - return fn(m) diff --git a/probdiffeq/impl/scalar/_normal.py b/probdiffeq/impl/scalar/_normal.py deleted file mode 100644 index 00775cf1..00000000 --- a/probdiffeq/impl/scalar/_normal.py +++ /dev/null @@ -1,7 +0,0 @@ -from probdiffeq.backend import containers -from probdiffeq.backend.typing import Any - - -class Normal(containers.NamedTuple): - mean: Any - cholesky: Any diff --git a/probdiffeq/impl/scalar/_prototypes.py b/probdiffeq/impl/scalar/_prototypes.py deleted file mode 100644 index abcf3724..00000000 --- a/probdiffeq/impl/scalar/_prototypes.py +++ /dev/null @@ -1,19 +0,0 @@ -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _prototypes -from probdiffeq.impl.scalar import _normal - - -class PrototypeBackend(_prototypes.PrototypeBackend): - def qoi(self): - return np.empty(()) - - def observed(self): - mean = np.empty(()) - cholesky = np.empty(()) - return _normal.Normal(mean, cholesky) - - def error_estimate(self): - return np.empty(()) - - def output_scale(self): - return np.empty(()) diff --git a/probdiffeq/impl/scalar/_ssm_util.py b/probdiffeq/impl/scalar/_ssm_util.py deleted file mode 100644 index 1d19771d..00000000 --- a/probdiffeq/impl/scalar/_ssm_util.py +++ /dev/null @@ -1,56 +0,0 @@ -"""SSM utilities.""" - -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _ssm_util -from probdiffeq.impl.scalar import _normal -from probdiffeq.util import cholesky_util, cond_util, ibm_util - - -class SSMUtilBackend(_ssm_util.SSMUtilBackend): - def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives): - if len(tcoeffs) != num_derivatives + 1: - msg1 = "The number of Taylor coefficients does not match " - msg2 = "the number of derivatives in the implementation." - raise ValueError(msg1 + msg2) - m0_matrix = np.stack(tcoeffs) - m0_corrected = np.reshape(m0_matrix, (-1,), order="F") - c_sqrtm0_corrected = np.zeros((num_derivatives + 1, num_derivatives + 1)) - return _normal.Normal(m0_corrected, c_sqrtm0_corrected) - - def preconditioner_apply(self, rv, p, /): - return _normal.Normal(p * rv.mean, p[:, None] * rv.cholesky) - - def preconditioner_apply_cond(self, cond, p, p_inv, /): - A, noise = cond - A = p[:, None] * A * p_inv[None, :] - noise = _normal.Normal(p * noise.mean, p[:, None] * noise.cholesky) - return cond_util.Conditional(A, noise) - - def ibm_transitions(self, num_derivatives, output_scale=1.0): - a, q_sqrtm = ibm_util.system_matrices_1d(num_derivatives, output_scale) - q0 = np.zeros((num_derivatives + 1,)) - noise = _normal.Normal(q0, q_sqrtm) - - precon_fun = ibm_util.preconditioner_prepare(num_derivatives=num_derivatives) - - def discretise(dt): - p, p_inv = precon_fun(dt) - return cond_util.Conditional(a, noise), (p, p_inv) - - return discretise - - def identity_conditional(self, ndim, /): - transition = np.eye(ndim) - mean = np.zeros((ndim,)) - cov_sqrtm = np.zeros((ndim, ndim)) - noise = _normal.Normal(mean, cov_sqrtm) - return cond_util.Conditional(transition, noise) - - def standard_normal(self, ndim, /, output_scale): - mean = np.zeros((ndim,)) - cholesky = output_scale * np.eye(ndim) - return _normal.Normal(mean, cholesky) - - def update_mean(self, mean, x, /, num): - sum_updated = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) - return sum_updated / np.sqrt(num + 1) diff --git a/probdiffeq/impl/scalar/_stats.py b/probdiffeq/impl/scalar/_stats.py deleted file mode 100644 index ca4188da..00000000 --- a/probdiffeq/impl/scalar/_stats.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Random variable implementation.""" - -from probdiffeq.backend import functools, linalg -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _stats - - -class StatsBackend(_stats.StatsBackend): - def mahalanobis_norm_relative(self, u, /, rv): - res_white = (u - rv.mean) / rv.cholesky - return np.abs(res_white) / np.sqrt(rv.mean.size) - - def logpdf(self, u, /, rv): - dx = u - rv.mean - w = linalg.solve_triangular(rv.cholesky.T, dx, trans="T") - - maha_term = linalg.vector_dot(w, w) - - diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2) - slogdet = np.sum(np.log(np.abs(diagonal))) - logdet_term = 2.0 * slogdet - return -0.5 * (logdet_term + maha_term + u.size * np.log(np.pi() * 2)) - - def standard_deviation(self, rv): - if rv.cholesky.ndim > 1: - return functools.vmap(self.standard_deviation)(rv) - - return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky)) - - def mean(self, rv): - return rv.mean - - def sample_shape(self, rv): - return rv.mean.shape diff --git a/probdiffeq/impl/scalar/_transform.py b/probdiffeq/impl/scalar/_transform.py deleted file mode 100644 index 8d3fd2a7..00000000 --- a/probdiffeq/impl/scalar/_transform.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Random variable transformations.""" - -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _transform -from probdiffeq.impl.scalar import _normal -from probdiffeq.util import cholesky_util, cond_util - - -class TransformBackend(_transform.TransformBackend): - def marginalise(self, rv, transformation, /): - # currently, assumes that A(rv.cholesky) is a vector, not a matrix. - matmul, b = transformation - cholesky_new = cholesky_util.triu_via_qr(matmul(rv.cholesky)[:, None]) - cholesky_new_squeezed = np.reshape(cholesky_new, ()) - return _normal.Normal(matmul(rv.mean) + b, cholesky_new_squeezed) - - def revert(self, rv, transformation, /): - # Assumes that A maps a vector to a scalar... - - # Extract information - A, b = transformation - - # QR-decomposition - # (todo: rename revert_conditional_noisefree - # to transformation_revert_cov_sqrt()) - r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree( - R_X_F=A(rv.cholesky)[:, None], R_X=rv.cholesky.T - ) - cholesky_obs = np.reshape(r_obs, ()) - cholesky_cor = r_cor.T - gain = np.squeeze_along_axis(gain, axis=-1) - - # Gather terms and return - m_cor = rv.mean - gain * (A(rv.mean) + b) - corrected = _normal.Normal(m_cor, cholesky_cor) - observed = _normal.Normal(A(rv.mean) + b, cholesky_obs) - return observed, cond_util.Conditional(gain, corrected) diff --git a/probdiffeq/impl/scalar/_variable.py b/probdiffeq/impl/scalar/_variable.py deleted file mode 100644 index 27922ea5..00000000 --- a/probdiffeq/impl/scalar/_variable.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Random variable implementation.""" - -from probdiffeq.backend import functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _variable -from probdiffeq.impl.scalar import _normal - - -class VariableBackend(_variable.VariableBackend): - def rescale_cholesky(self, rv, factor): - if np.ndim(factor) > 0: - return functools.vmap(self.rescale_cholesky)(rv, factor) - return _normal.Normal(rv.mean, factor * rv.cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + rv.cholesky @ unit_sample - - def to_multivariate_normal(self, rv): - return rv.mean, rv.cholesky @ rv.cholesky.T diff --git a/probdiffeq/impl/scalar/factorised_impl.py b/probdiffeq/impl/scalar/factorised_impl.py deleted file mode 100644 index b136f17c..00000000 --- a/probdiffeq/impl/scalar/factorised_impl.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Scalar factorisation.""" - -from probdiffeq.impl import _impl -from probdiffeq.impl.scalar import ( - _conditional, - _hidden_model, - _linearise, - _prototypes, - _ssm_util, - _stats, - _transform, - _variable, -) - - -class Scalar(_impl.FactorisedImpl): - """Scalar factorisation.""" - - def linearise(self): - return _linearise.LinearisationBackend() - - def conditional(self): - return _conditional.ConditionalBackend() - - def transform(self): - return _transform.TransformBackend() - - def ssm_util(self): - return _ssm_util.SSMUtilBackend() - - def prototypes(self): - return _prototypes.PrototypeBackend() - - def hidden_model(self): - return _hidden_model.HiddenModelBackend() - - def stats(self): - return _stats.StatsBackend() - - def variable(self): - return _variable.VariableBackend()