diff --git a/probdiffeq/impl/_conditional.py b/probdiffeq/impl/_conditional.py index 3c31328b..52c0c643 100644 --- a/probdiffeq/impl/_conditional.py +++ b/probdiffeq/impl/_conditional.py @@ -147,10 +147,11 @@ def to_derivative(self, i, standard_deviation): class DenseConditional(ConditionalBackend): - def __init__(self, ode_shape, num_derivatives, unravel): + def __init__(self, ode_shape, num_derivatives, unravel, flat_shape): self.ode_shape = ode_shape self.num_derivatives = num_derivatives self.unravel = unravel + self.flat_shape = flat_shape def apply(self, x, conditional, /): matrix, noise = conditional @@ -178,8 +179,6 @@ def revert(self, rv, conditional, /): mean, cholesky = rv.mean, rv.cholesky # QR-decomposition - # (todo: rename revert_conditional_noisefree to - # revert_transformation_cov_sqrt()) r_obs, (r_cor, gain) = cholesky_util.revert_conditional( R_X_F=(matrix @ cholesky).T, R_X=cholesky.T, R_YX=noise.cholesky.T ) @@ -208,8 +207,7 @@ def ibm_transitions(self, *, output_scale): A = np.kron(a, eye_d) Q = np.kron(q_sqrtm, eye_d) - ndim = d * (self.num_derivatives + 1) - q0 = np.zeros((ndim,)) + q0 = np.zeros(self.flat_shape) noise = _normal.Normal(q0, Q) precon_fun = preconditioner_prepare(num_derivatives=self.num_derivatives) @@ -230,34 +228,25 @@ def preconditioner_apply(self, cond, p, p_inv, /): return Conditional(A, noise) def to_derivative(self, i, standard_deviation): - a0 = functools.partial(self._select, idx_or_slice=i) + x = np.zeros(self.flat_shape) + + def select(a): + return self.unravel(a)[i] + + linop = functools.jacrev(select)(x) (d,) = self.ode_shape bias = np.zeros((d,)) eye = np.eye(d) noise = _normal.Normal(bias, standard_deviation * eye) - - x = np.zeros(((self.num_derivatives + 1) * d,)) - linop = _jac_materialize(lambda s, _p: self._autobatch_linop(a0)(s), inputs=x) return Conditional(linop, noise) - def _select(self, x, /, idx_or_slice): - return self.unravel(x)[idx_or_slice] - - @staticmethod - def _autobatch_linop(fun): - def fun_(x): - if np.ndim(x) > 1: - return functools.vmap(fun_, in_axes=1, out_axes=1)(x) - return fun(x) - - return fun_ - class IsotropicConditional(ConditionalBackend): - def __init__(self, *, ode_shape, num_derivatives): + def __init__(self, *, ode_shape, num_derivatives, unravel_tree): self.ode_shape = ode_shape self.num_derivatives = num_derivatives + self.unravel_tree = unravel_tree def apply(self, x, conditional, /): A, noise = conditional @@ -332,22 +321,24 @@ def preconditioner_apply(self, cond, p, p_inv, /): return Conditional(A_new, noise) def to_derivative(self, i, standard_deviation): - def A(x): - return x[[i], ...] + def select(a): + return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0] + + m = np.zeros((self.num_derivatives + 1,)) + linop = functools.jacrev(select)(m) bias = np.zeros(self.ode_shape) eye = np.eye(1) noise = _normal.Normal(bias, standard_deviation * eye) - m = np.zeros((self.num_derivatives + 1,)) - linop = _jac_materialize(lambda s, _p: A(s), inputs=m) return Conditional(linop, noise) class BlockDiagConditional(ConditionalBackend): - def __init__(self, *, ode_shape, num_derivatives): + def __init__(self, *, ode_shape, num_derivatives, unravel_tree): self.ode_shape = ode_shape self.num_derivatives = num_derivatives + self.unravel_tree = unravel_tree def apply(self, x, conditional, /): if np.ndim(x) == 1: @@ -434,15 +425,11 @@ def preconditioner_apply(self, cond, p, p_inv, /): return Conditional(A_new, noise) def to_derivative(self, i, standard_deviation): - def A(x): - return x[[i], ...] - - @functools.vmap - def lo(y): - return _jac_materialize(lambda s, _p: A(s), inputs=y) + def select(a): + return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0] x = np.zeros((*self.ode_shape, self.num_derivatives + 1)) - linop = lo(x) + linop = functools.vmap(functools.jacrev(select))(x) bias = np.zeros((*self.ode_shape, 1)) eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...] @@ -494,7 +481,3 @@ def _batch_gram(k, /): def _binom(n, k): return np.factorial(n) / (np.factorial(n - k) * np.factorial(k)) - - -def _jac_materialize(func, /, *, inputs, params=None): - return functools.jacrev(lambda v: func(v, params))(inputs) diff --git a/probdiffeq/impl/_stats.py b/probdiffeq/impl/_stats.py index 2f744fa2..b4721cca 100644 --- a/probdiffeq/impl/_stats.py +++ b/probdiffeq/impl/_stats.py @@ -42,10 +42,6 @@ def rescale_cholesky(self, rv, factor, /): def qoi(self, rv): raise NotImplementedError - @abc.abstractmethod - def marginal_nth_derivative(self, rv): - raise NotImplementedError - @abc.abstractmethod def qoi_from_sample(self, sample, /): raise NotImplementedError @@ -105,37 +101,11 @@ def to_multivariate_normal(self, rv): def qoi(self, rv): return self.qoi_from_sample(rv.mean) - def marginal_nth_derivative(self, rv, i): - if rv.mean.ndim > 1: - return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( - rv, i - ) - - m = self._select(rv.mean, i) - c = functools.vmap(self._select, in_axes=(1, None), out_axes=1)(rv.cholesky, i) - c = cholesky_util.triu_via_qr(c.T) - return _normal.Normal(m, c.T) - def qoi_from_sample(self, sample, /): if np.ndim(sample) > 1: return functools.vmap(self.qoi_from_sample)(sample) return self.unravel(sample) - def _select(self, x, /, idx_or_slice): - x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F") - if isinstance(idx_or_slice, int) and idx_or_slice > x_reshaped.shape[0]: - raise ValueError - return x_reshaped[idx_or_slice] - - @staticmethod - def _autobatch_linop(fun): - def fun_(x): - if np.ndim(x) > 1: - return functools.vmap(fun_, in_axes=1, out_axes=1)(x) - return fun(x) - - return fun_ - def update_mean(self, mean, x, /, num): nominator = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x) denominator = np.sqrt(num + 1) @@ -198,19 +168,6 @@ def to_multivariate_normal(self, rv): mean = rv.mean.reshape((-1,), order="F") return (mean, cov) - def marginal_nth_derivative(self, rv, i): - if np.ndim(rv.mean) > 2: - return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( - rv, i - ) - - if i > np.shape(rv.mean)[0]: - raise ValueError - - mean = rv.mean[i, :] - cholesky = cholesky_util.triu_via_qr(rv.cholesky[[i], :].T).T - return _normal.Normal(mean, cholesky) - def qoi(self, rv): return self.qoi_from_sample(rv.mean) @@ -287,22 +244,6 @@ def qoi_from_sample(self, sample, /): return functools.vmap(self.qoi_from_sample)(sample) return self.unravel(sample) - def marginal_nth_derivative(self, rv, i): - if np.ndim(rv.mean) > 2: - return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))( - rv, i - ) - - if i > np.shape(rv.mean)[0]: - raise ValueError - - mean = rv.mean[:, i] - cholesky = functools.vmap(cholesky_util.triu_via_qr)( - (rv.cholesky[:, i, :])[..., None] - ) - cholesky = np.transpose(cholesky, axes=(0, 2, 1)) - return _normal.Normal(mean, cholesky) - def update_mean(self, mean, x, /, num): if np.ndim(mean) > 0: assert np.shape(mean) == np.shape(x) diff --git a/probdiffeq/impl/impl.py b/probdiffeq/impl/impl.py index 8c23b3c9..c4db1235 100644 --- a/probdiffeq/impl/impl.py +++ b/probdiffeq/impl/impl.py @@ -40,7 +40,7 @@ def choose(which: str, /, *, tcoeffs_like) -> FactImpl: def _select_dense(*, tcoeffs_like) -> FactImpl: ode_shape = tcoeffs_like[0].shape - _, unravel = tree_util.ravel_pytree(tcoeffs_like) + flat, unravel = tree_util.ravel_pytree(tcoeffs_like) num_derivatives = len(tcoeffs_like) - 1 @@ -49,7 +49,10 @@ def _select_dense(*, tcoeffs_like) -> FactImpl: linearise = _linearise.DenseLinearisation(ode_shape=ode_shape, unravel=unravel) stats = _stats.DenseStats(ode_shape=ode_shape, unravel=unravel) conditional = _conditional.DenseConditional( - ode_shape=ode_shape, num_derivatives=num_derivatives, unravel=unravel + ode_shape=ode_shape, + num_derivatives=num_derivatives, + unravel=unravel, + flat_shape=flat.shape, ) transform = _conditional.DenseTransform() return FactImpl( @@ -76,7 +79,7 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl: stats = _stats.IsotropicStats(ode_shape=ode_shape, unravel=unravel) linearise = _linearise.IsotropicLinearisation() conditional = _conditional.IsotropicConditional( - ode_shape=ode_shape, num_derivatives=num_derivatives + ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree ) transform = _conditional.IsotropicTransform() return FactImpl( @@ -103,7 +106,7 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl: stats = _stats.BlockDiagStats(ode_shape=ode_shape, unravel=unravel) linearise = _linearise.BlockDiagLinearisation() conditional = _conditional.BlockDiagConditional( - ode_shape=ode_shape, num_derivatives=num_derivatives + ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree ) transform = _conditional.BlockDiagTransform(ode_shape=ode_shape) return FactImpl(