Skip to content

Commit

Permalink
Move the Conditional type from util.cond_util to `impl._condition…
Browse files Browse the repository at this point in the history
…al` (#765)

* Move the Conditional type to impl.conditional

* Fix failing tests
  • Loading branch information
pnkraemer authored Jun 17, 2024
1 parent 5efef33 commit e7f936d
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 50 deletions.
31 changes: 21 additions & 10 deletions probdiffeq/impl/_conditional.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Conditionals."""

from probdiffeq.backend import abc, functools, linalg
from probdiffeq.backend import abc, containers, functools, linalg
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import Any, Array
from probdiffeq.impl import _normal
from probdiffeq.util import cholesky_util, cond_util
from probdiffeq.util import cholesky_util


class Conditional(containers.NamedTuple):
"""Conditional distributions."""

matmul: Array # or anything with a __matmul__ implementation
noise: Any # Usually a random-variable type


class ConditionalBackend(abc.ABC):
Expand All @@ -23,6 +31,9 @@ def apply(self, x, conditional, /):
def merge(self, cond1, cond2, /):
raise NotImplementedError

def conditional(self, matmul, noise):
return Conditional(matmul, noise)


class ScalarConditional(ConditionalBackend):
def marginalise(self, rv, conditional, /):
Expand All @@ -44,7 +55,7 @@ def revert(self, rv, conditional, /):

marginal = _normal.Normal(m_ext, r_ext.T)
noise = _normal.Normal(m_cond, r_bw_p.T)
return marginal, cond_util.Conditional(g_bw_p, noise)
return marginal, Conditional(g_bw_p, noise)

def apply(self, x, conditional, /):
matrix, noise = conditional
Expand All @@ -61,7 +72,7 @@ def merge(self, previous, incoming, /):
Xi = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T

noise = _normal.Normal(xi, Xi)
return cond_util.Conditional(g, noise)
return Conditional(g, noise)


class DenseConditional(ConditionalBackend):
Expand All @@ -84,7 +95,7 @@ def merge(self, cond1, cond2, /):
Xi = cholesky_util.sum_of_sqrtm_factors(
R_stack=((A @ d.cholesky).T, b.cholesky.T)
)
return cond_util.Conditional(g, _normal.Normal(xi, Xi.T))
return Conditional(g, _normal.Normal(xi, Xi.T))

def revert(self, rv, conditional, /):
matrix, noise = conditional
Expand All @@ -102,7 +113,7 @@ def revert(self, rv, conditional, /):
m_cor = mean - gain @ mean_observed
corrected = _normal.Normal(m_cor, r_cor.T)
observed = _normal.Normal(mean_observed, r_obs.T)
return observed, cond_util.Conditional(gain, corrected)
return observed, Conditional(gain, corrected)


class IsotropicConditional(ConditionalBackend):
Expand Down Expand Up @@ -133,7 +144,7 @@ def merge(self, cond1, cond2, /):
Xi = cholesky_util.sum_of_sqrtm_factors(R_stack).T

noise = _normal.Normal(xi, Xi)
return cond_util.Conditional(g, noise)
return Conditional(g, noise)

def revert(self, rv, conditional, /):
matrix, noise = conditional
Expand All @@ -149,7 +160,7 @@ def revert(self, rv, conditional, /):

extrapolated = _normal.Normal(extrapolated_mean, extrapolated_cholesky)
corrected = _normal.Normal(corrected_mean, corrected_cholesky)
return extrapolated, cond_util.Conditional(gain, corrected)
return extrapolated, Conditional(gain, corrected)


class BlockDiagConditional(ConditionalBackend):
Expand Down Expand Up @@ -185,7 +196,7 @@ def merge(self, cond1, cond2, /):
Xi = _transpose(functools.vmap(cholesky_util.sum_of_sqrtm_factors)(R_stack))

noise = _normal.Normal(xi, Xi)
return cond_util.Conditional(g, noise)
return Conditional(g, noise)

def revert(self, rv, conditional, /):
A, noise = conditional
Expand All @@ -204,7 +215,7 @@ def revert(self, rv, conditional, /):
m_cor = rv.mean - (gain @ (mean_observed[..., None]))[..., 0]
corrected = _normal.Normal(m_cor, cholesky_cor)
observed = _normal.Normal(mean_observed, cholesky_obs)
return observed, cond_util.Conditional(gain, corrected)
return observed, Conditional(gain, corrected)


def _transpose(matrix):
Expand Down
12 changes: 6 additions & 6 deletions probdiffeq/impl/_hidden_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from probdiffeq.backend import abc, functools
from probdiffeq.backend import numpy as np
from probdiffeq.impl import _normal
from probdiffeq.util import cholesky_util, cond_util, linop_util
from probdiffeq.impl import _conditional, _normal
from probdiffeq.util import cholesky_util, linop_util


class HiddenModelBackend(abc.ABC):
Expand Down Expand Up @@ -51,7 +51,7 @@ def A(x):
eye = np.eye(1)
noise = _normal.Normal(bias, standard_deviation * eye)
linop = linop_util.parametrised_linop(lambda s, _p: A(s))
return cond_util.Conditional(linop, noise)
return _conditional.Conditional(linop, noise)


class DenseHiddenModel(HiddenModelBackend):
Expand Down Expand Up @@ -90,7 +90,7 @@ def conditional_to_derivative(self, i, standard_deviation):
linop = linop_util.parametrised_linop(
lambda s, _p: self._autobatch_linop(a0)(s)
)
return cond_util.Conditional(linop, noise)
return _conditional.Conditional(linop, noise)

def _select(self, x, /, idx_or_slice):
x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F")
Expand Down Expand Up @@ -139,7 +139,7 @@ def A(x):
eye = np.eye(1)
noise = _normal.Normal(bias, standard_deviation * eye)
linop = linop_util.parametrised_linop(lambda s, _p: A(s))
return cond_util.Conditional(linop, noise)
return _conditional.Conditional(linop, noise)


class BlockDiagHiddenModel(HiddenModelBackend):
Expand Down Expand Up @@ -176,4 +176,4 @@ def A(x):
eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...]
noise = _normal.Normal(bias, standard_deviation * eye)
linop = linop_util.parametrised_linop(lambda s, _p: A(s))
return cond_util.Conditional(linop, noise)
return _conditional.Conditional(linop, noise)
26 changes: 13 additions & 13 deletions probdiffeq/impl/_ssm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from probdiffeq.backend import abc, functools
from probdiffeq.backend import numpy as np
from probdiffeq.impl import _normal
from probdiffeq.util import cholesky_util, cond_util, ibm_util
from probdiffeq.impl import _conditional, _normal
from probdiffeq.util import cholesky_util, ibm_util


class SSMUtilBackend(abc.ABC):
Expand Down Expand Up @@ -57,7 +57,7 @@ def preconditioner_apply_cond(self, cond, p, p_inv, /):
A, noise = cond
A = p[:, None] * A * p_inv[None, :]
noise = _normal.Normal(p * noise.mean, p[:, None] * noise.cholesky)
return cond_util.Conditional(A, noise)
return _conditional.Conditional(A, noise)

def ibm_transitions(self, num_derivatives, output_scale=1.0):
a, q_sqrtm = ibm_util.system_matrices_1d(num_derivatives, output_scale)
Expand All @@ -68,7 +68,7 @@ def ibm_transitions(self, num_derivatives, output_scale=1.0):

def discretise(dt):
p, p_inv = precon_fun(dt)
return cond_util.Conditional(a, noise), (p, p_inv)
return _conditional.Conditional(a, noise), (p, p_inv)

return discretise

Expand All @@ -77,7 +77,7 @@ def identity_conditional(self, ndim, /):
mean = np.zeros((ndim,))
cov_sqrtm = np.zeros((ndim, ndim))
noise = _normal.Normal(mean, cov_sqrtm)
return cond_util.Conditional(transition, noise)
return _conditional.Conditional(transition, noise)

def standard_normal(self, ndim, /, output_scale):
mean = np.zeros((ndim,))
Expand Down Expand Up @@ -110,7 +110,7 @@ def discretise(dt):
p, p_inv = precon_fun(dt)
p = np.tile(p, d)
p_inv = np.tile(p_inv, d)
return cond_util.Conditional(A, noise), (p, p_inv)
return _conditional.Conditional(A, noise), (p, p_inv)

return discretise

Expand All @@ -121,7 +121,7 @@ def identity_conditional(self, ndim, /):
A = np.eye(n)
m = np.zeros((n,))
C = np.zeros((n, n))
return cond_util.Conditional(A, _normal.Normal(m, C))
return _conditional.Conditional(A, _normal.Normal(m, C))

def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives):
if len(tcoeffs) != num_derivatives + 1:
Expand Down Expand Up @@ -151,7 +151,7 @@ def preconditioner_apply_cond(self, cond, p, p_inv, /):
A, noise = cond
noise = self.preconditioner_apply(noise, p)
A = p[:, None] * A * p_inv[None, :]
return cond_util.Conditional(A, noise)
return _conditional.Conditional(A, noise)

def standard_normal(self, ndim, /, output_scale):
eye_n = np.eye(ndim)
Expand All @@ -178,7 +178,7 @@ def ibm_transitions(self, num_derivatives, output_scale):

def discretise(dt):
p, p_inv = precon_fun(dt)
return cond_util.Conditional(A, noise), (p, p_inv)
return _conditional.Conditional(A, noise), (p, p_inv)

return discretise

Expand All @@ -187,7 +187,7 @@ def identity_conditional(self, num_hidden_states_per_ode_dim, /):
c0 = np.zeros((num_hidden_states_per_ode_dim, num_hidden_states_per_ode_dim))
noise = _normal.Normal(m0, c0)
matrix = np.eye(num_hidden_states_per_ode_dim)
return cond_util.Conditional(matrix, noise)
return _conditional.Conditional(matrix, noise)

def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives):
if len(tcoeffs) != num_derivatives + 1:
Expand All @@ -208,7 +208,7 @@ def preconditioner_apply_cond(self, cond, p, p_inv, /):
A_new = p[:, None] * A * p_inv[None, :]

noise = _normal.Normal(p[:, None] * noise.mean, p[:, None] * noise.cholesky)
return cond_util.Conditional(A_new, noise)
return _conditional.Conditional(A_new, noise)

def standard_normal(self, num, /, output_scale):
mean = np.zeros((num, *self.ode_shape))
Expand Down Expand Up @@ -245,7 +245,7 @@ def identity_conditional(self, ndim, /):
noise = _normal.Normal(m0, c0)

matrix = np.ones((*self.ode_shape, 1, 1)) * np.eye(ndim, ndim)[None, ...]
return cond_util.Conditional(matrix, noise)
return _conditional.Conditional(matrix, noise)

def normal_from_tcoeffs(self, tcoeffs, /, num_derivatives):
if len(tcoeffs) != num_derivatives + 1:
Expand All @@ -267,7 +267,7 @@ def preconditioner_apply_cond(self, cond, p, p_inv, /):
A, noise = cond
A_new = p[None, :, None] * A * p_inv[None, None, :]
noise = self.preconditioner_apply(noise, p)
return cond_util.Conditional(A_new, noise)
return _conditional.Conditional(A_new, noise)

def standard_normal(self, ndim, output_scale):
mean = np.zeros((*self.ode_shape, ndim))
Expand Down
12 changes: 6 additions & 6 deletions probdiffeq/impl/_transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from probdiffeq.backend import abc, containers, functools
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import Array, Callable
from probdiffeq.impl import _normal
from probdiffeq.util import cholesky_util, cond_util
from probdiffeq.impl import _conditional, _normal
from probdiffeq.util import cholesky_util


class Transformation(containers.NamedTuple):
Expand Down Expand Up @@ -48,7 +48,7 @@ def revert(self, rv, transformation, /):
m_cor = rv.mean - gain * (A(rv.mean) + b)
corrected = _normal.Normal(m_cor, cholesky_cor)
observed = _normal.Normal(A(rv.mean) + b, cholesky_obs)
return observed, cond_util.Conditional(gain, corrected)
return observed, _conditional.Conditional(gain, corrected)


class DenseTransform(TransformBackend):
Expand All @@ -72,7 +72,7 @@ def revert(self, rv, transformation, /):
m_cor = mean - gain @ (A @ mean + b)
corrected = _normal.Normal(m_cor, r_cor.T)
observed = _normal.Normal(A @ mean + b, r_obs.T)
return observed, cond_util.Conditional(gain, corrected)
return observed, _conditional.Conditional(gain, corrected)


class IsotropicTransform(TransformBackend):
Expand Down Expand Up @@ -101,7 +101,7 @@ def revert(self, rv, transformation, /):
m_cor = mean - gain * mean_observed
corrected = _normal.Normal(m_cor, cholesky_cor)
observed = _normal.Normal(mean_observed, cholesky_obs)
return observed, cond_util.Conditional(gain, corrected)
return observed, _conditional.Conditional(gain, corrected)


class BlockDiagTransform(TransformBackend):
Expand Down Expand Up @@ -132,7 +132,7 @@ def revert(self, rv, transformation, /):
m_cor = rv.mean - (gain * (mean_observed[..., None]))[..., 0]
corrected = _normal.Normal(m_cor, cholesky_cor)
observed = _normal.Normal(mean_observed, cholesky_obs)
return observed, cond_util.Conditional(gain, corrected)
return observed, _conditional.Conditional(gain, corrected)


def _transpose(arr, /):
Expand Down
4 changes: 2 additions & 2 deletions probdiffeq/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import Any
from probdiffeq.impl import impl
from probdiffeq.util import cond_util, filter_util
from probdiffeq.util import filter_util

# TODO: the functions in here should only depend on posteriors / strategies!

Expand Down Expand Up @@ -274,4 +274,4 @@ def _markov_rescale_cholesky(markov_seq: MarkovSeq, factor) -> MarkovSeq:

def _rescale_cholesky_conditional(conditional, factor, /):
noise_new = impl.variable.rescale_cholesky(conditional.noise, factor)
return cond_util.Conditional(conditional.matmul, noise_new)
return impl.conditional.conditional(conditional.matmul, noise_new)
13 changes: 0 additions & 13 deletions probdiffeq/util/cond_util.py

This file was deleted.

0 comments on commit e7f936d

Please sign in to comment.