From c6ec77d3083c58005a8109a52a78129cf5b8c2e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Mon, 17 Jun 2024 09:56:06 +0200 Subject: [PATCH] Move content of `impl.variable` to `impl.stats` (#767) * Move impl.variable.transform_unit_sample to impl.stats since it classifies as 'stats-stuff' * Move impl.variable.* to impl.stats.* because variable-code is statistics-code * Delete the (now empty) impl.variable module --- probdiffeq/backend/testing.py | 4 +- probdiffeq/impl/_impl.py | 16 ------ probdiffeq/impl/_stats.py | 64 ++++++++++++++++++++++++ probdiffeq/impl/_variable.py | 86 --------------------------------- probdiffeq/ivpsolvers.py | 6 +-- probdiffeq/stats.py | 10 ++-- tests/test_impl/test_logpdfs.py | 2 +- 7 files changed, 75 insertions(+), 113 deletions(-) delete mode 100644 probdiffeq/impl/_variable.py diff --git a/probdiffeq/backend/testing.py b/probdiffeq/backend/testing.py index 4f7b8ae5..d2f87c3f 100644 --- a/probdiffeq/backend/testing.py +++ b/probdiffeq/backend/testing.py @@ -48,8 +48,8 @@ def allclose_partial(*args): def marginals_allclose(m1, m2, /): - m1, c1 = impl.variable.to_multivariate_normal(m1) - m2, c2 = impl.variable.to_multivariate_normal(m2) + m1, c1 = impl.stats.to_multivariate_normal(m1) + m2, c2 = impl.stats.to_multivariate_normal(m2) means_allclose = jnp.allclose(m1, m2) covs_allclose = jnp.allclose(c1, c2) diff --git a/probdiffeq/impl/_impl.py b/probdiffeq/impl/_impl.py index 00205e21..0ae8ad0b 100644 --- a/probdiffeq/impl/_impl.py +++ b/probdiffeq/impl/_impl.py @@ -12,7 +12,6 @@ _ssm_util, _stats, _transform, - _variable, ) @@ -69,12 +68,6 @@ def prototypes(self) -> _prototypes.PrototypeBackend: return self._fact.prototypes raise ValueError(self.error_msg()) - @property - def variable(self) -> _variable.VariableBackend: - if self._fact is not None: - return self._fact.variable - raise ValueError(self.error_msg()) - @property def hidden_model(self) -> _hidden_model.HiddenModelBackend: if self._fact is not None: @@ -96,7 +89,6 @@ def error_msg(): class FactorisedImpl: prototypes: _prototypes.PrototypeBackend ssm_util: _ssm_util.SSMUtilBackend - variable: _variable.VariableBackend stats: _stats.StatsBackend linearise: _linearise.LinearisationBackend conditional: _conditional.ConditionalBackend @@ -127,7 +119,6 @@ def choose(which: str, /, *, ode_shape=None) -> FactorisedImpl: def _select_scalar() -> FactorisedImpl: prototypes = _prototypes.ScalarPrototype() ssm_util = _ssm_util.ScalarSSMUtil() - variable = _variable.ScalarVariable() stats = _stats.ScalarStats() linearise = _linearise.ScalarLinearisation() conditional = _conditional.ScalarConditional() @@ -136,7 +127,6 @@ def _select_scalar() -> FactorisedImpl: return FactorisedImpl( prototypes=prototypes, ssm_util=ssm_util, - variable=variable, stats=stats, linearise=linearise, conditional=conditional, @@ -152,7 +142,6 @@ def _select_dense(*, ode_shape) -> FactorisedImpl: stats = _stats.DenseStats(ode_shape=ode_shape) conditional = _conditional.DenseConditional() transform = _transform.DenseTransform() - variable = _variable.DenseVariable(ode_shape=ode_shape) hidden_model = _hidden_model.DenseHiddenModel(ode_shape=ode_shape) return FactorisedImpl( linearise=linearise, @@ -160,7 +149,6 @@ def _select_dense(*, ode_shape) -> FactorisedImpl: conditional=conditional, ssm_util=ssm_util, prototypes=prototypes, - variable=variable, hidden_model=hidden_model, stats=stats, ) @@ -169,7 +157,6 @@ def _select_dense(*, ode_shape) -> FactorisedImpl: def _select_isotropic(*, ode_shape) -> FactorisedImpl: prototypes = _prototypes.IsotropicPrototype(ode_shape=ode_shape) ssm_util = _ssm_util.IsotropicSSMUtil(ode_shape=ode_shape) - variable = _variable.IsotropicVariable(ode_shape=ode_shape) stats = _stats.IsotropicStats(ode_shape=ode_shape) linearise = _linearise.IsotropicLinearisation() conditional = _conditional.IsotropicConditional() @@ -178,7 +165,6 @@ def _select_isotropic(*, ode_shape) -> FactorisedImpl: return FactorisedImpl( prototypes=prototypes, ssm_util=ssm_util, - variable=variable, stats=stats, linearise=linearise, conditional=conditional, @@ -190,7 +176,6 @@ def _select_isotropic(*, ode_shape) -> FactorisedImpl: def _select_blockdiag(*, ode_shape) -> FactorisedImpl: prototypes = _prototypes.BlockDiagPrototype(ode_shape=ode_shape) ssm_util = _ssm_util.BlockDiagSSMUtil(ode_shape=ode_shape) - variable = _variable.BlockDiagVariable(ode_shape=ode_shape) stats = _stats.BlockDiagStats(ode_shape=ode_shape) linearise = _linearise.BlockDiagLinearisation() conditional = _conditional.BlockDiagConditional() @@ -199,7 +184,6 @@ def _select_blockdiag(*, ode_shape) -> FactorisedImpl: return FactorisedImpl( prototypes=prototypes, ssm_util=ssm_util, - variable=variable, stats=stats, linearise=linearise, conditional=conditional, diff --git a/probdiffeq/impl/_stats.py b/probdiffeq/impl/_stats.py index 54bc4619..f551d36a 100644 --- a/probdiffeq/impl/_stats.py +++ b/probdiffeq/impl/_stats.py @@ -24,6 +24,18 @@ def mean(self, rv): def sample_shape(self, rv): raise NotImplementedError + @abc.abstractmethod + def transform_unit_sample(self, unit_sample, /, rv): + raise NotImplementedError + + @abc.abstractmethod + def to_multivariate_normal(self, u, rv): + raise NotImplementedError + + @abc.abstractmethod + def rescale_cholesky(self, rv, factor, /): + raise NotImplementedError + class ScalarStats(StatsBackend): def mahalanobis_norm_relative(self, u, /, rv): @@ -53,6 +65,17 @@ def mean(self, rv): def sample_shape(self, rv): return rv.mean.shape + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + rv.cholesky @ unit_sample + + def rescale_cholesky(self, rv, factor): + if np.ndim(factor) > 0: + return functools.vmap(self.rescale_cholesky)(rv, factor) + return _normal.Normal(rv.mean, factor * rv.cholesky) + + def to_multivariate_normal(self, rv): + return rv.mean, rv.cholesky @ rv.cholesky.T + class DenseStats(StatsBackend): def __init__(self, ode_shape): @@ -90,6 +113,16 @@ def standard_deviation(self, rv): def sample_shape(self, rv): return rv.mean.shape + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + rv.cholesky @ unit_sample + + def rescale_cholesky(self, rv, factor, /): + cholesky = factor[..., None, None] * rv.cholesky + return _normal.Normal(rv.mean, cholesky) + + def to_multivariate_normal(self, rv): + return rv.mean, rv.cholesky @ rv.cholesky.T + class IsotropicStats(StatsBackend): def __init__(self, ode_shape): @@ -132,6 +165,20 @@ def standard_deviation(self, rv): def sample_shape(self, rv): return rv.mean.shape + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + rv.cholesky @ unit_sample + + def rescale_cholesky(self, rv, factor, /): + cholesky = factor[..., None, None] * rv.cholesky + return _normal.Normal(rv.mean, cholesky) + + def to_multivariate_normal(self, rv): + eye_d = np.eye(*self.ode_shape) + cov = rv.cholesky @ rv.cholesky.T + cov = np.kron(eye_d, cov) + mean = rv.mean.reshape((-1,), order="F") + return (mean, cov) + class BlockDiagStats(StatsBackend): def __init__(self, ode_shape): @@ -169,3 +216,20 @@ def standard_deviation(self, rv): return functools.vmap(self.standard_deviation)(rv) return np.sqrt(linalg.vector_dot(rv.cholesky, rv.cholesky)) + + def transform_unit_sample(self, unit_sample, /, rv): + return rv.mean + (rv.cholesky @ unit_sample[..., None])[..., 0] + + def rescale_cholesky(self, rv, factor, /): + cholesky = factor[..., None, None] * rv.cholesky + return _normal.Normal(rv.mean, cholesky) + + def to_multivariate_normal(self, rv): + mean = np.reshape(rv.mean.T, (-1,), order="F") + cov = np.block_diag(self._cov_dense(rv.cholesky)) + return (mean, cov) + + def _cov_dense(self, cholesky): + if cholesky.ndim > 2: + return functools.vmap(self._cov_dense)(cholesky) + return cholesky @ cholesky.T diff --git a/probdiffeq/impl/_variable.py b/probdiffeq/impl/_variable.py deleted file mode 100644 index 0ea53b73..00000000 --- a/probdiffeq/impl/_variable.py +++ /dev/null @@ -1,86 +0,0 @@ -from probdiffeq.backend import abc, functools -from probdiffeq.backend import numpy as np -from probdiffeq.impl import _normal - - -class VariableBackend(abc.ABC): - @abc.abstractmethod - def to_multivariate_normal(self, u, rv): - raise NotImplementedError - - @abc.abstractmethod - def rescale_cholesky(self, rv, factor, /): - raise NotImplementedError - - @abc.abstractmethod - def transform_unit_sample(self, unit_sample, /, rv): - raise NotImplementedError - - -class ScalarVariable(VariableBackend): - def rescale_cholesky(self, rv, factor): - if np.ndim(factor) > 0: - return functools.vmap(self.rescale_cholesky)(rv, factor) - return _normal.Normal(rv.mean, factor * rv.cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + rv.cholesky @ unit_sample - - def to_multivariate_normal(self, rv): - return rv.mean, rv.cholesky @ rv.cholesky.T - - -class DenseVariable(VariableBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def rescale_cholesky(self, rv, factor, /): - cholesky = factor[..., None, None] * rv.cholesky - return _normal.Normal(rv.mean, cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + rv.cholesky @ unit_sample - - def to_multivariate_normal(self, rv): - return rv.mean, rv.cholesky @ rv.cholesky.T - - -class IsotropicVariable(VariableBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def rescale_cholesky(self, rv, factor, /): - cholesky = factor[..., None, None] * rv.cholesky - return _normal.Normal(rv.mean, cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + rv.cholesky @ unit_sample - - def to_multivariate_normal(self, rv): - eye_d = np.eye(*self.ode_shape) - cov = rv.cholesky @ rv.cholesky.T - cov = np.kron(eye_d, cov) - mean = rv.mean.reshape((-1,), order="F") - return (mean, cov) - - -class BlockDiagVariable(VariableBackend): - def __init__(self, ode_shape): - self.ode_shape = ode_shape - - def rescale_cholesky(self, rv, factor, /): - cholesky = factor[..., None, None] * rv.cholesky - return _normal.Normal(rv.mean, cholesky) - - def transform_unit_sample(self, unit_sample, /, rv): - return rv.mean + (rv.cholesky @ unit_sample[..., None])[..., 0] - - def to_multivariate_normal(self, rv): - mean = np.reshape(rv.mean.T, (-1,), order="F") - cov = np.block_diag(self._cov_dense(rv.cholesky)) - return (mean, cov) - - def _cov_dense(self, cholesky): - if cholesky.ndim > 2: - return functools.vmap(self._cov_dense)(cholesky) - return cholesky @ cholesky.T diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 815f9b23..e98806f9 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -284,7 +284,7 @@ def complete(self, _ssv, extra, /, output_scale): # Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity) A, noise = cond - noise = impl.variable.rescale_cholesky(noise, output_scale) + noise = impl.stats.rescale_cholesky(noise, output_scale) extrapolated_p, cond_p = impl.conditional.revert(rv_p, (A, noise)) extrapolated = impl.ssm_util.preconditioner_apply(extrapolated_p, p) cond = impl.ssm_util.preconditioner_apply_cond(cond_p, p, p_inv) @@ -377,7 +377,7 @@ def complete(self, _rv, extra, /, output_scale): # Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity) A, noise = cond - noise = impl.variable.rescale_cholesky(noise, output_scale) + noise = impl.stats.rescale_cholesky(noise, output_scale) extrapolated_p, cond_p = impl.conditional.revert(rv_p, (A, noise)) extrapolated = impl.ssm_util.preconditioner_apply(extrapolated_p, p) cond = impl.ssm_util.preconditioner_apply_cond(cond_p, p, p_inv) @@ -503,7 +503,7 @@ def complete(self, _ssv, extra, /, output_scale): # Extrapolate the Cholesky factor (re-extrapolate the mean for simplicity) A, noise = cond - noise = impl.variable.rescale_cholesky(noise, output_scale) + noise = impl.stats.rescale_cholesky(noise, output_scale) extrapolated_p = impl.conditional.marginalise(rv_p, (A, noise)) extrapolated = impl.ssm_util.preconditioner_apply(extrapolated_p, p) diff --git a/probdiffeq/stats.py b/probdiffeq/stats.py index d3cc1dc1..ed6c3184 100644 --- a/probdiffeq/stats.py +++ b/probdiffeq/stats.py @@ -51,14 +51,14 @@ def body_fun(carry, conditionals_and_base_samples): conditional, base = conditionals_and_base_samples rv = impl.conditional.apply(samp_prev, conditional) - smp = impl.variable.transform_unit_sample(base, rv) + smp = impl.stats.transform_unit_sample(base, rv) qoi = impl.hidden_model.qoi_from_sample(smp) return (qoi, smp), (qoi, smp) base_sample_init, base_sample_body = base_sample[0], base_sample[1:] # Compute a sample at the terminal value - init_sample = impl.variable.transform_unit_sample(base_sample_init, markov_seq.init) + init_sample = impl.stats.transform_unit_sample(base_sample_init, markov_seq.init) init_qoi = impl.hidden_model.qoi_from_sample(init_sample) init_val = (init_qoi, init_sample) @@ -262,16 +262,16 @@ def calibrate(x, /, output_scale): output_scale = output_scale[-1] if isinstance(x, MarkovSeq): return _markov_rescale_cholesky(x, output_scale) - return impl.variable.rescale_cholesky(x, output_scale) + return impl.stats.rescale_cholesky(x, output_scale) def _markov_rescale_cholesky(markov_seq: MarkovSeq, factor) -> MarkovSeq: """Rescale the Cholesky factor of the covariance of a Markov sequence.""" - init = impl.variable.rescale_cholesky(markov_seq.init, factor) + init = impl.stats.rescale_cholesky(markov_seq.init, factor) cond = _rescale_cholesky_conditional(markov_seq.conditional, factor) return MarkovSeq(init=init, conditional=cond) def _rescale_cholesky_conditional(conditional, factor, /): - noise_new = impl.variable.rescale_cholesky(conditional.noise, factor) + noise_new = impl.stats.rescale_cholesky(conditional.noise, factor) return impl.conditional.conditional(conditional.matmul, noise_new) diff --git a/tests/test_impl/test_logpdfs.py b/tests/test_impl/test_logpdfs.py index f49e4f5f..4622661d 100644 --- a/tests/test_impl/test_logpdfs.py +++ b/tests/test_impl/test_logpdfs.py @@ -11,7 +11,7 @@ def test_logpdf(): rv = setup.rv() - (mean_dense, cov_dense) = impl.variable.to_multivariate_normal(rv) + (mean_dense, cov_dense) = impl.stats.to_multivariate_normal(rv) u = np.ones_like(impl.stats.mean(rv)) u_dense = np.ones_like(mean_dense)