Skip to content

Commit

Permalink
Condense the directory structure of probdiffeq.impl (#764)
Browse files Browse the repository at this point in the history
* Restructure impl._impl.py and move some components

* Move DenseConditional and DenseStats

* Move DenseTransform

* Move DenseVariable

* Move DenseHiddenModel

* Migrate the old type checks and error messages

* Move IsotropicPrototype

* Move IsotropicSSMUtil

* Move IsotropicVariable

* Move IsotropicStats

* Move IsotropicLinearisation

* Move IsotropicConditional

* Move IsotropicTransform

* Move IsotropicHiddenModel and thereby complete moving isotropic implementations

* Move ScalarPrototypes

* Move ScalarSSMUtil

* Move ScalarVariable

* Move ScalarStats

* Move ScalarLinearisation

* Fix an import-error in impl.variable

* Move ScalarConditional

* Move ScalarTransform

* Delete the (outdated) impl.scalar._transform.py

* Move ScalarHiddenModel and thereby complete moving Scalar implementations

* Move BlockDiagPrototype

* Delete unused files in impl.blockdiag

* Move BlockDiagSSMUtil

* Move BlockDiagVariable

* Move BlockDiagStats

* Move BlockDiagLinearisation

* Move BlockDiagConditional

* Move BlockDiagTransform

* Move BlockDiagHiddenModel and thereby complete moving BlockDiag implementations

* Correct typing-issues in impl.impl

* Improve the documentation of 'impl'
  • Loading branch information
pnkraemer authored Jun 17, 2024
1 parent bd08a9a commit 5efef33
Show file tree
Hide file tree
Showing 54 changed files with 1,422 additions and 1,718 deletions.
1 change: 1 addition & 0 deletions probdiffeq/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
"""State-space model implementation.
Refer to the quickstart for information.
[Here is a link](https://pnkraemer.github.io/probdiffeq/examples_quickstart/easy_example/).
"""
192 changes: 191 additions & 1 deletion probdiffeq/impl/_conditional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Conditionals."""

from probdiffeq.backend import abc
from probdiffeq.backend import abc, functools, linalg
from probdiffeq.backend import numpy as np
from probdiffeq.impl import _normal
from probdiffeq.util import cholesky_util, cond_util


class ConditionalBackend(abc.ABC):
Expand All @@ -19,3 +22,190 @@ def apply(self, x, conditional, /):
@abc.abstractmethod
def merge(self, cond1, cond2, /):
raise NotImplementedError


class ScalarConditional(ConditionalBackend):
    """Conditional operations for the scalar state-space model.

    Every conditional unpacks into a pair ``(matrix, noise)`` where ``noise``
    exposes ``mean`` and ``cholesky`` attributes.
    """

    def marginalise(self, rv, conditional, /):
        """Push ``rv`` through ``conditional`` and return the resulting Normal."""
        transition, noise = conditional

        # Combine the transposed factor of the transformed variable with the
        # transposed noise factor into a single factor.
        transformed_factor_T = (transition @ rv.cholesky).T
        noise_factor_T = noise.cholesky.T
        combined_T = cholesky_util.sum_of_sqrtm_factors(
            R_stack=(transformed_factor_T, noise_factor_T)
        )

        new_mean = transition @ rv.mean + noise.mean
        return _normal.Normal(new_mean, combined_T.T)

    def revert(self, rv, conditional, /):
        """Return the marginal and the reverted conditional as a pair."""
        transition, noise = conditional

        r_marginal, (r_backward, gain) = cholesky_util.revert_conditional(
            R_X_F=(transition @ rv.cholesky).T,
            R_X=rv.cholesky.T,
            R_YX=noise.cholesky.T,
        )
        mean_marginal = transition @ rv.mean + noise.mean
        mean_backward = rv.mean - gain @ mean_marginal

        marginal = _normal.Normal(mean_marginal, r_marginal.T)
        backward_noise = _normal.Normal(mean_backward, r_backward.T)
        return marginal, cond_util.Conditional(gain, backward_noise)

    def apply(self, x, conditional, /):
        """Evaluate ``conditional`` at the value ``x``."""
        matrix, noise = conditional
        # Singleton axes are squeezed away before the vector dot product.
        squeezed = np.squeeze(matrix)
        shifted_mean = linalg.vector_dot(squeezed, x) + noise.mean
        return _normal.Normal(shifted_mean, noise.cholesky)

    def merge(self, previous, incoming, /):
        """Compose two conditionals into a single one."""
        A, b = previous
        C, d = incoming

        merged_matrix = A @ C
        merged_mean = A @ d.mean + b.mean
        factors_T = ((A @ d.cholesky).T, b.cholesky.T)
        merged_cholesky = cholesky_util.sum_of_sqrtm_factors(R_stack=factors_T).T

        merged_noise = _normal.Normal(merged_mean, merged_cholesky)
        return cond_util.Conditional(merged_matrix, merged_noise)


class DenseConditional(ConditionalBackend):
    """Conditional operations for the dense state-space model.

    Every conditional unpacks into a pair ``(matrix, noise)`` where ``noise``
    exposes ``mean`` and ``cholesky`` attributes.
    """

    def apply(self, x, conditional, /):
        """Evaluate ``conditional`` at the value ``x``."""
        transition, noise = conditional
        shifted_mean = transition @ x + noise.mean
        return _normal.Normal(shifted_mean, noise.cholesky)

    def marginalise(self, rv, conditional, /):
        """Push ``rv`` through ``conditional`` and return the resulting Normal."""
        transition, noise = conditional

        transformed_factor_T = (transition @ rv.cholesky).T
        noise_factor_T = noise.cholesky.T
        combined_T = cholesky_util.sum_of_sqrtm_factors(
            R_stack=(transformed_factor_T, noise_factor_T)
        )

        new_mean = transition @ rv.mean + noise.mean
        return _normal.Normal(new_mean, combined_T.T)

    def merge(self, cond1, cond2, /):
        """Compose two conditionals into a single one."""
        A, b = cond1
        C, d = cond2

        merged_matrix = A @ C
        merged_mean = A @ d.mean + b.mean
        factors_T = ((A @ d.cholesky).T, b.cholesky.T)
        merged_factor_T = cholesky_util.sum_of_sqrtm_factors(R_stack=factors_T)

        merged_noise = _normal.Normal(merged_mean, merged_factor_T.T)
        return cond_util.Conditional(merged_matrix, merged_noise)

    def revert(self, rv, conditional, /):
        """Return the observed marginal and the reverted conditional as a pair."""
        transition, noise = conditional
        prior_mean = rv.mean
        prior_cholesky = rv.cholesky

        # QR-decomposition
        # (todo: rename revert_conditional_noisefree to
        # revert_transformation_cov_sqrt())
        r_observed, (r_corrected, gain) = cholesky_util.revert_conditional(
            R_X_F=(transition @ prior_cholesky).T,
            R_X=prior_cholesky.T,
            R_YX=noise.cholesky.T,
        )

        # Assemble both means, then wrap everything into Normals.
        observed_mean = transition @ prior_mean + noise.mean
        corrected_mean = prior_mean - gain @ observed_mean

        observed = _normal.Normal(observed_mean, r_observed.T)
        corrected = _normal.Normal(corrected_mean, r_corrected.T)
        return observed, cond_util.Conditional(gain, corrected)


class IsotropicConditional(ConditionalBackend):
    """Conditional operations for the isotropic state-space model.

    Every conditional unpacks into a pair ``(matrix, noise)`` where ``noise``
    exposes ``mean`` and ``cholesky`` attributes.
    """

    def apply(self, x, conditional, /):
        """Evaluate ``conditional`` at the value ``x``."""
        transition, noise = conditional
        # If the gain is qoi-to-hidden, the data is a (d,) array.
        # The isotropic model needs an explicit leading axis in that case.
        x = x[None, :] if np.ndim(x) == 1 else x
        shifted_mean = transition @ x + noise.mean
        return _normal.Normal(shifted_mean, noise.cholesky)

    def marginalise(self, rv, conditional, /):
        """Push ``rv`` through ``conditional`` and return the resulting Normal."""
        transition, noise = conditional

        transformed_factor_T = (transition @ rv.cholesky).T
        noise_factor_T = noise.cholesky.T
        combined = cholesky_util.sum_of_sqrtm_factors(
            R_stack=(transformed_factor_T, noise_factor_T)
        ).T

        new_mean = transition @ rv.mean + noise.mean
        return _normal.Normal(new_mean, combined)

    def merge(self, cond1, cond2, /):
        """Compose two conditionals into a single one."""
        A, b = cond1
        C, d = cond2

        merged_matrix = A @ C
        merged_mean = A @ d.mean + b.mean
        factors_T = ((A @ d.cholesky).T, b.cholesky.T)
        merged_cholesky = cholesky_util.sum_of_sqrtm_factors(factors_T).T

        merged_noise = _normal.Normal(merged_mean, merged_cholesky)
        return cond_util.Conditional(merged_matrix, merged_noise)

    def revert(self, rv, conditional, /):
        """Return the extrapolated marginal and the reverted conditional."""
        transition, noise = conditional

        r_extrapolated, (r_backward, gain) = cholesky_util.revert_conditional(
            R_X_F=(transition @ rv.cholesky).T,
            R_X=rv.cholesky.T,
            R_YX=noise.cholesky.T,
        )

        mean_extrapolated = transition @ rv.mean + noise.mean
        mean_corrected = rv.mean - gain @ mean_extrapolated

        # The returned factors are transposed back before wrapping them.
        extrapolated = _normal.Normal(mean_extrapolated, r_extrapolated.T)
        corrected = _normal.Normal(mean_corrected, r_backward.T)
        return extrapolated, cond_util.Conditional(gain, corrected)


class BlockDiagConditional(ConditionalBackend):
    """Conditional operations for the block-diagonal state-space model.

    Operations are batched over the leading axis of the involved arrays,
    either through ``functools.vmap`` or through batched matrix products.
    """

    def apply(self, x, conditional, /):
        """Evaluate ``conditional`` at the value ``x``, one block at a time."""
        # One-dimensional inputs get a trailing axis so that each block
        # receives a vector.
        if np.ndim(x) == 1:
            x = x[..., None]

        def apply_single_block(block_matrix, block_x, block_noise):
            block_mean = block_matrix @ block_x + block_noise.mean
            return _normal.Normal(block_mean, block_noise.cholesky)

        matrix, noise = conditional
        return functools.vmap(apply_single_block)(matrix, x, noise)

    def marginalise(self, rv, conditional, /):
        """Push ``rv`` through ``conditional`` and return the resulting Normal."""
        matrix, noise = conditional
        assert matrix.ndim == 3

        new_mean = np.einsum("ijk,ik->ij", matrix, rv.mean) + noise.mean

        factors_T = (_transpose(matrix @ rv.cholesky), _transpose(noise.cholesky))
        sum_per_block = functools.vmap(cholesky_util.sum_of_sqrtm_factors)
        combined_T = sum_per_block(factors_T)
        return _normal.Normal(new_mean, _transpose(combined_T))

    def merge(self, cond1, cond2, /):
        """Compose two conditionals into a single one."""
        A, b = cond1
        C, d = cond2

        merged_matrix = A @ C
        # Batched matrix-vector product via a temporary trailing axis.
        merged_mean = (A @ d.mean[..., None])[..., 0] + b.mean

        factors_T = (_transpose(A @ d.cholesky), _transpose(b.cholesky))
        sum_per_block = functools.vmap(cholesky_util.sum_of_sqrtm_factors)
        merged_cholesky = _transpose(sum_per_block(factors_T))

        merged_noise = _normal.Normal(merged_mean, merged_cholesky)
        return cond_util.Conditional(merged_matrix, merged_noise)

    def revert(self, rv, conditional, /):
        """Return the observed marginal and the reverted conditional as a pair."""
        A, noise = conditional

        # Transposed factors, one per block.
        prior_factor_T = _transpose(rv.cholesky)
        noise_factor_T = _transpose(noise.cholesky)
        transformed_factor_T = _transpose(A @ rv.cholesky)

        revert_per_block = functools.vmap(cholesky_util.revert_conditional)
        r_obs, (r_cor, gain) = revert_per_block(
            transformed_factor_T, prior_factor_T, noise_factor_T
        )

        cholesky_obs = _transpose(r_obs)
        cholesky_cor = _transpose(r_cor)

        # Batched matrix-vector products via a temporary trailing axis.
        mean_observed = (A @ rv.mean[..., None])[..., 0] + noise.mean
        mean_corrected = rv.mean - (gain @ (mean_observed[..., None]))[..., 0]

        observed = _normal.Normal(mean_observed, cholesky_obs)
        corrected = _normal.Normal(mean_corrected, cholesky_cor)
        return observed, cond_util.Conditional(gain, corrected)


def _transpose(matrix):
return np.transpose(matrix, axes=(0, 2, 1))
162 changes: 161 additions & 1 deletion probdiffeq/impl/_hidden_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from probdiffeq.backend import abc
from probdiffeq.backend import abc, functools
from probdiffeq.backend import numpy as np
from probdiffeq.impl import _normal
from probdiffeq.util import cholesky_util, cond_util, linop_util


class HiddenModelBackend(abc.ABC):
Expand All @@ -17,3 +20,160 @@ def qoi_from_sample(self, sample, /):
@abc.abstractmethod
def conditional_to_derivative(self, i, standard_deviation):
raise NotImplementedError


class ScalarHiddenModel(HiddenModelBackend):
    """Hidden-model operations for the scalar state-space model."""

    def qoi(self, rv):
        """Return entry 0 along the last axis of ``rv.mean``."""
        return rv.mean[..., 0]

    def marginal_nth_derivative(self, rv, i):
        """Return the marginal Normal of entry ``i`` of the state.

        Means with more than one axis are handled by mapping this method over
        the leading axis of ``rv``.

        Raises:
            ValueError: If ``i`` is not smaller than the length of ``rv.mean``.
        """
        if rv.mean.ndim > 1:
            return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
                rv, i
            )

        # Valid indices are 0, ..., shape[0] - 1; i == shape[0] is already
        # out of range and must be rejected as well.
        num_entries = rv.mean.shape[0]
        if i >= num_entries:
            msg = f"Index {i} is out of range for a state with {num_entries} entries."
            raise ValueError(msg)

        m = rv.mean[i]
        c = rv.cholesky[[i], :]
        chol = cholesky_util.triu_via_qr(c.T)
        # Both the mean and the factor are reshaped to zero-dimensional arrays.
        return _normal.Normal(np.reshape(m, ()), np.reshape(chol, ()))

    def qoi_from_sample(self, sample, /):
        """Return entry 0 of ``sample``."""
        return sample[0]

    def conditional_to_derivative(self, i, standard_deviation):
        """Build a conditional that selects entry ``i`` of its input.

        The noise is a Normal with a zero-dimensional zero mean and
        ``standard_deviation * eye(1)`` as its second argument.
        """

        def A(x):
            # List-indexing keeps the selected axis (length one).
            return x[[i], ...]

        bias = np.zeros(())
        eye = np.eye(1)
        noise = _normal.Normal(bias, standard_deviation * eye)
        linop = linop_util.parametrised_linop(lambda s, _p: A(s))
        return cond_util.Conditional(linop, noise)


class DenseHiddenModel(HiddenModelBackend):
    """Hidden-model operations for the dense state-space model.

    Flat states are reshaped to ``(-1, *ode_shape)`` in Fortran order before
    individual entries along the leading axis are selected.
    """

    def __init__(self, ode_shape):
        self.ode_shape = ode_shape

    def qoi(self, rv):
        """Return row 0 of the reshaped mean (mapped over leading axes)."""
        if np.ndim(rv.mean) > 1:
            return functools.vmap(self.qoi)(rv)
        mean_reshaped = np.reshape(rv.mean, (-1, *self.ode_shape), order="F")
        return mean_reshaped[0]

    def marginal_nth_derivative(self, rv, i):
        """Return the marginal Normal of row ``i`` of the reshaped state.

        Means with more than one axis are handled by mapping this method over
        the leading axis of ``rv``.

        Raises:
            ValueError: If ``i`` is an int that is out of range (see ``_select``).
        """
        if rv.mean.ndim > 1:
            return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
                rv, i
            )

        m = self._select(rv.mean, i)
        # Apply the same selection to every column of the factor.
        c = functools.vmap(self._select, in_axes=(1, None), out_axes=1)(rv.cholesky, i)
        c = cholesky_util.triu_via_qr(c.T)
        return _normal.Normal(m, c.T)

    def qoi_from_sample(self, sample, /):
        """Return row 0 of the reshaped sample."""
        sample_reshaped = np.reshape(sample, (-1, *self.ode_shape), order="F")
        return sample_reshaped[0]

    # TODO: move to linearise.py?
    def conditional_to_derivative(self, i, standard_deviation):
        """Build a conditional that selects row ``i`` of the reshaped input.

        The noise is a Normal with a zero mean of shape ``(d,)`` and
        ``standard_deviation * eye(d)`` as its second argument, where
        ``(d,) = ode_shape``.
        """
        a0 = functools.partial(self._select, idx_or_slice=i)

        # Unpacking fails unless ode_shape has exactly one entry.
        (d,) = self.ode_shape
        bias = np.zeros((d,))
        eye = np.eye(d)
        noise = _normal.Normal(bias, standard_deviation * eye)
        linop = linop_util.parametrised_linop(
            lambda s, _p: self._autobatch_linop(a0)(s)
        )
        return cond_util.Conditional(linop, noise)

    def _select(self, x, /, idx_or_slice):
        """Reshape ``x`` to ``(-1, *ode_shape)`` (Fortran order) and index it.

        Raises:
            ValueError: If ``idx_or_slice`` is an int that is not smaller than
                the number of rows of the reshaped array.
        """
        x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F")
        # Valid integer indices are 0, ..., shape[0] - 1; an index equal to
        # shape[0] is already out of range and must be rejected as well.
        num_rows = x_reshaped.shape[0]
        if isinstance(idx_or_slice, int) and idx_or_slice >= num_rows:
            msg = f"Index {idx_or_slice} is out of range for {num_rows} rows."
            raise ValueError(msg)
        return x_reshaped[idx_or_slice]

    @staticmethod
    def _autobatch_linop(fun):
        """Wrap ``fun`` so inputs with more than one axis are mapped over axis 1."""

        def fun_(x):
            if np.ndim(x) > 1:
                return functools.vmap(fun_, in_axes=1, out_axes=1)(x)
            return fun(x)

        return fun_


class IsotropicHiddenModel(HiddenModelBackend):
    """Hidden-model operations for the isotropic state-space model."""

    def __init__(self, ode_shape):
        self.ode_shape = ode_shape

    def qoi(self, rv):
        """Return entry 0 along the second-to-last axis of ``rv.mean``."""
        return rv.mean[..., 0, :]

    def marginal_nth_derivative(self, rv, i):
        """Return the marginal Normal of row ``i`` of the state.

        Means with more than two axes are handled by mapping this method over
        the leading axis of ``rv``.

        Raises:
            ValueError: If ``i`` is not smaller than the number of rows of
                ``rv.mean``.
        """
        if np.ndim(rv.mean) > 2:
            return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
                rv, i
            )

        # Valid indices are 0, ..., shape[0] - 1; i == shape[0] is already
        # out of range and must be rejected as well.
        num_rows = np.shape(rv.mean)[0]
        if i >= num_rows:
            msg = f"Index {i} is out of range for a state with {num_rows} rows."
            raise ValueError(msg)

        mean = rv.mean[i, :]
        cholesky = cholesky_util.triu_via_qr(rv.cholesky[[i], :].T).T
        return _normal.Normal(mean, cholesky)

    def qoi_from_sample(self, sample, /):
        """Return row 0 of ``sample``."""
        return sample[0, :]

    def conditional_to_derivative(self, i, standard_deviation):
        """Build a conditional that selects row ``i`` of its input.

        The noise is a Normal with a zero mean of shape ``ode_shape`` and
        ``standard_deviation * eye(1)`` as its second argument.
        """

        def A(x):
            # List-indexing keeps the selected axis (length one).
            return x[[i], ...]

        bias = np.zeros(self.ode_shape)
        eye = np.eye(1)
        noise = _normal.Normal(bias, standard_deviation * eye)
        linop = linop_util.parametrised_linop(lambda s, _p: A(s))
        return cond_util.Conditional(linop, noise)


class BlockDiagHiddenModel(HiddenModelBackend):
    """Hidden-model operations for the block-diagonal state-space model."""

    def __init__(self, ode_shape):
        self.ode_shape = ode_shape

    def qoi(self, rv):
        """Return entry 0 along the last axis of ``rv.mean``."""
        return rv.mean[..., 0]

    def marginal_nth_derivative(self, rv, i):
        """Return the marginal Normal of entry ``i`` along axis 1 of the state.

        Means with more than two axes are handled by mapping this method over
        the leading axis of ``rv``.

        Raises:
            ValueError: If ``i`` is not smaller than the length of axis 1 of
                ``rv.mean``.
        """
        if np.ndim(rv.mean) > 2:
            return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
                rv, i
            )

        # The index is applied along axis 1 below (``rv.mean[:, i]``), so the
        # bound must come from axis 1 as well. Valid indices are
        # 0, ..., shape[1] - 1.
        num_entries = np.shape(rv.mean)[1]
        if i >= num_entries:
            msg = f"Index {i} is out of range for blocks with {num_entries} entries."
            raise ValueError(msg)

        mean = rv.mean[:, i]
        # Row i of every block becomes a column before the batched QR step.
        cholesky = functools.vmap(cholesky_util.triu_via_qr)(
            (rv.cholesky[:, i, :])[..., None]
        )
        cholesky = np.transpose(cholesky, axes=(0, 2, 1))
        return _normal.Normal(mean, cholesky)

    def qoi_from_sample(self, sample, /):
        """Return entry 0 along the last axis of ``sample``."""
        return sample[..., 0]

    def conditional_to_derivative(self, i, standard_deviation):
        """Build a conditional that selects entry ``i`` along axis 1 of its input.

        The noise is a Normal with a zero mean of shape ``(*ode_shape, 1)`` and,
        as its second argument, ``standard_deviation`` times an array of shape
        ``(*ode_shape, 1, 1)`` filled with ones.
        """

        def A(x):
            # List-indexing keeps the selected axis (length one).
            return x[:, [i], ...]

        bias = np.zeros((*self.ode_shape, 1))
        eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...]
        noise = _normal.Normal(bias, standard_deviation * eye)
        linop = linop_util.parametrised_linop(lambda s, _p: A(s))
        return cond_util.Conditional(linop, noise)
Loading

0 comments on commit 5efef33

Please sign in to comment.