Skip to content

Commit

Permalink
Move content of impl.variable to impl.stats (#767)
Browse files Browse the repository at this point in the history
* Move impl.variable.transform_unit_sample to impl.stats since it classifies as 'stats-stuff'

* Move impl.variable.* to impl.stats.* because variable-code is statistics-code

* Delete the (now empty) impl.variable module
  • Loading branch information
pnkraemer authored Jun 17, 2024
1 parent 4ed9354 commit c6ec77d
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 113 deletions.
4 changes: 2 additions & 2 deletions probdiffeq/backend/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def allclose_partial(*args):


def marginals_allclose(m1, m2, /):
m1, c1 = impl.variable.to_multivariate_normal(m1)
m2, c2 = impl.variable.to_multivariate_normal(m2)
m1, c1 = impl.stats.to_multivariate_normal(m1)
m2, c2 = impl.stats.to_multivariate_normal(m2)

means_allclose = jnp.allclose(m1, m2)
covs_allclose = jnp.allclose(c1, c2)
Expand Down
16 changes: 0 additions & 16 deletions probdiffeq/impl/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
_ssm_util,
_stats,
_transform,
_variable,
)


Expand Down Expand Up @@ -69,12 +68,6 @@ def prototypes(self) -> _prototypes.PrototypeBackend:
return self._fact.prototypes
raise ValueError(self.error_msg())

@property
def variable(self) -> _variable.VariableBackend:
if self._fact is not None:
return self._fact.variable
raise ValueError(self.error_msg())

@property
def hidden_model(self) -> _hidden_model.HiddenModelBackend:
if self._fact is not None:
Expand All @@ -96,7 +89,6 @@ def error_msg():
class FactorisedImpl:
prototypes: _prototypes.PrototypeBackend
ssm_util: _ssm_util.SSMUtilBackend
variable: _variable.VariableBackend
stats: _stats.StatsBackend
linearise: _linearise.LinearisationBackend
conditional: _conditional.ConditionalBackend
Expand Down Expand Up @@ -127,7 +119,6 @@ def choose(which: str, /, *, ode_shape=None) -> FactorisedImpl:
def _select_scalar() -> FactorisedImpl:
prototypes = _prototypes.ScalarPrototype()
ssm_util = _ssm_util.ScalarSSMUtil()
variable = _variable.ScalarVariable()
stats = _stats.ScalarStats()
linearise = _linearise.ScalarLinearisation()
conditional = _conditional.ScalarConditional()
Expand All @@ -136,7 +127,6 @@ def _select_scalar() -> FactorisedImpl:
return FactorisedImpl(
prototypes=prototypes,
ssm_util=ssm_util,
variable=variable,
stats=stats,
linearise=linearise,
conditional=conditional,
Expand All @@ -152,15 +142,13 @@ def _select_dense(*, ode_shape) -> FactorisedImpl:
stats = _stats.DenseStats(ode_shape=ode_shape)
conditional = _conditional.DenseConditional()
transform = _transform.DenseTransform()
variable = _variable.DenseVariable(ode_shape=ode_shape)
hidden_model = _hidden_model.DenseHiddenModel(ode_shape=ode_shape)
return FactorisedImpl(
linearise=linearise,
transform=transform,
conditional=conditional,
ssm_util=ssm_util,
prototypes=prototypes,
variable=variable,
hidden_model=hidden_model,
stats=stats,
)
Expand All @@ -169,7 +157,6 @@ def _select_dense(*, ode_shape) -> FactorisedImpl:
def _select_isotropic(*, ode_shape) -> FactorisedImpl:
prototypes = _prototypes.IsotropicPrototype(ode_shape=ode_shape)
ssm_util = _ssm_util.IsotropicSSMUtil(ode_shape=ode_shape)
variable = _variable.IsotropicVariable(ode_shape=ode_shape)
stats = _stats.IsotropicStats(ode_shape=ode_shape)
linearise = _linearise.IsotropicLinearisation()
conditional = _conditional.IsotropicConditional()
Expand All @@ -178,7 +165,6 @@ def _select_isotropic(*, ode_shape) -> FactorisedImpl:
return FactorisedImpl(
prototypes=prototypes,
ssm_util=ssm_util,
variable=variable,
stats=stats,
linearise=linearise,
conditional=conditional,
Expand All @@ -190,7 +176,6 @@ def _select_isotropic(*, ode_shape) -> FactorisedImpl:
def _select_blockdiag(*, ode_shape) -> FactorisedImpl:
prototypes = _prototypes.BlockDiagPrototype(ode_shape=ode_shape)
ssm_util = _ssm_util.BlockDiagSSMUtil(ode_shape=ode_shape)
variable = _variable.BlockDiagVariable(ode_shape=ode_shape)
stats = _stats.BlockDiagStats(ode_shape=ode_shape)
linearise = _linearise.BlockDiagLinearisation()
conditional = _conditional.BlockDiagConditional()
Expand All @@ -199,7 +184,6 @@ def _select_blockdiag(*, ode_shape) -> FactorisedImpl:
return FactorisedImpl(
prototypes=prototypes,
ssm_util=ssm_util,
variable=variable,
stats=stats,
linearise=linearise,
conditional=conditional,
Expand Down
64 changes: 64 additions & 0 deletions probdiffeq/impl/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def mean(self, rv):
def sample_shape(self, rv):
raise NotImplementedError

@abc.abstractmethod
def transform_unit_sample(self, unit_sample, /, rv):
raise NotImplementedError

@abc.abstractmethod
def to_multivariate_normal(self, u, rv):
raise NotImplementedError

@abc.abstractmethod
def rescale_cholesky(self, rv, factor, /):
raise NotImplementedError


class ScalarStats(StatsBackend):
def mahalanobis_norm_relative(self, u, /, rv):
Expand Down Expand Up @@ -53,6 +65,17 @@ def mean(self, rv):
def sample_shape(self, rv):
return rv.mean.shape

def transform_unit_sample(self, unit_sample, /, rv):
return rv.mean + rv.cholesky @ unit_sample

def rescale_cholesky(self, rv, factor):
if np.ndim(factor) > 0:
return functools.vmap(self.rescale_cholesky)(rv, factor)
return _normal.Normal(rv.mean, factor * rv.cholesky)

def to_multivariate_normal(self, rv):
return rv.mean, rv.cholesky @ rv.cholesky.T


class DenseStats(StatsBackend):
def __init__(self, ode_shape):
Expand Down Expand Up @@ -90,6 +113,16 @@ def standard_deviation(self, rv):
def sample_shape(self, rv):
return rv.mean.shape

def transform_unit_sample(self, unit_sample, /, rv):
return rv.mean + rv.cholesky @ unit_sample

def rescale_cholesky(self, rv, factor, /):
cholesky = factor[..., None, None] * rv.cholesky
return _normal.Normal(rv.mean, cholesky)

def to_multivariate_normal(self, rv):
return rv.mean, rv.cholesky @ rv.cholesky.T


class IsotropicStats(StatsBackend):
def __init__(self, ode_shape):
Expand Down Expand Up @@ -132,6 +165,20 @@ def standard_deviation(self, rv):
def sample_shape(self, rv):
return rv.mean.shape

def transform_unit_sample(self, unit_sample, /, rv):
return rv.mean + rv.cholesky @ unit_sample

def rescale_cholesky(self, rv, factor, /):
cholesky = factor[..., None, None] * rv.cholesky
return _normal.Normal(rv.mean, cholesky)

def to_multivariate_normal(self, rv):
eye_d = np.eye(*self.ode_shape)
cov = rv.cholesky @ rv.cholesky.T
cov = np.kron(eye_d, cov)
mean = rv.mean.reshape((-1,), order="F")
return (mean, cov)


class BlockDiagStats(StatsBackend):
def __init__(self, ode_shape):
Expand Down Expand Up @@ -169,3 +216,20 @@ def standard_deviation(self, rv):
return functools.vmap(self.standard_deviation)(rv)

return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky))

def transform_unit_sample(self, unit_sample, /, rv):
return rv.mean + (rv.cholesky @ unit_sample[..., None])[..., 0]

def rescale_cholesky(self, rv, factor, /):
cholesky = factor[..., None, None] * rv.cholesky
return _normal.Normal(rv.mean, cholesky)

def to_multivariate_normal(self, rv):
mean = np.reshape(rv.mean.T, (-1,), order="F")
cov = np.block_diag(self._cov_dense(rv.cholesky))
return (mean, cov)

def _cov_dense(self, cholesky):
if cholesky.ndim > 2:
return functools.vmap(self._cov_dense)(cholesky)
return cholesky @ cholesky.T
86 changes: 0 additions & 86 deletions probdiffeq/impl/_variable.py

This file was deleted.

6 changes: 3 additions & 3 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def complete(self, _ssv, extra, /, output_scale):

# Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity)
A, noise = cond
noise = impl.variable.rescale_cholesky(noise, output_scale)
noise = impl.stats.rescale_cholesky(noise, output_scale)
extrapolated_p, cond_p = impl.conditional.revert(rv_p, (A, noise))
extrapolated = impl.ssm_util.preconditioner_apply(extrapolated_p, p)
cond = impl.ssm_util.preconditioner_apply_cond(cond_p, p, p_inv)
Expand Down Expand Up @@ -377,7 +377,7 @@ def complete(self, _rv, extra, /, output_scale):

# Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity)
A, noise = cond
noise = impl.variable.rescale_cholesky(noise, output_scale)
noise = impl.stats.rescale_cholesky(noise, output_scale)
extrapolated_p, cond_p = impl.conditional.revert(rv_p, (A, noise))
extrapolated = impl.ssm_util.preconditioner_apply(extrapolated_p, p)
cond = impl.ssm_util.preconditioner_apply_cond(cond_p, p, p_inv)
Expand Down Expand Up @@ -503,7 +503,7 @@ def complete(self, _ssv, extra, /, output_scale):

# Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity)
A, noise = cond
noise = impl.variable.rescale_cholesky(noise, output_scale)
noise = impl.stats.rescale_cholesky(noise, output_scale)
extrapolated_p = impl.conditional.marginalise(rv_p, (A, noise))
extrapolated = impl.ssm_util.preconditioner_apply(extrapolated_p, p)

Expand Down
10 changes: 5 additions & 5 deletions probdiffeq/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def body_fun(carry, conditionals_and_base_samples):
conditional, base = conditionals_and_base_samples

rv = impl.conditional.apply(samp_prev, conditional)
smp = impl.variable.transform_unit_sample(base, rv)
smp = impl.stats.transform_unit_sample(base, rv)
qoi = impl.hidden_model.qoi_from_sample(smp)
return (qoi, smp), (qoi, smp)

base_sample_init, base_sample_body = base_sample[0], base_sample[1:]

# Compute a sample at the terminal value
init_sample = impl.variable.transform_unit_sample(base_sample_init, markov_seq.init)
init_sample = impl.stats.transform_unit_sample(base_sample_init, markov_seq.init)
init_qoi = impl.hidden_model.qoi_from_sample(init_sample)
init_val = (init_qoi, init_sample)

Expand Down Expand Up @@ -262,16 +262,16 @@ def calibrate(x, /, output_scale):
output_scale = output_scale[-1]
if isinstance(x, MarkovSeq):
return _markov_rescale_cholesky(x, output_scale)
return impl.variable.rescale_cholesky(x, output_scale)
return impl.stats.rescale_cholesky(x, output_scale)


def _markov_rescale_cholesky(markov_seq: MarkovSeq, factor) -> MarkovSeq:
"""Rescale the Cholesky factor of the covariance of a Markov sequence."""
init = impl.variable.rescale_cholesky(markov_seq.init, factor)
init = impl.stats.rescale_cholesky(markov_seq.init, factor)
cond = _rescale_cholesky_conditional(markov_seq.conditional, factor)
return MarkovSeq(init=init, conditional=cond)


def _rescale_cholesky_conditional(conditional, factor, /):
noise_new = impl.variable.rescale_cholesky(conditional.noise, factor)
noise_new = impl.stats.rescale_cholesky(conditional.noise, factor)
return impl.conditional.conditional(conditional.matmul, noise_new)
2 changes: 1 addition & 1 deletion tests/test_impl/test_logpdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def test_logpdf():
rv = setup.rv()
(mean_dense, cov_dense) = impl.variable.to_multivariate_normal(rv)
(mean_dense, cov_dense) = impl.stats.to_multivariate_normal(rv)

u = np.ones_like(impl.stats.mean(rv))
u_dense = np.ones_like(mean_dense)
Expand Down

0 comments on commit c6ec77d

Please sign in to comment.