Skip to content

Commit

Permalink
Internal: Remove private select-style functions because we have eas…
Browse files Browse the repository at this point in the history
…ier ways of accessing derivatives now (#799)

* Remove self._select from DenseConditional

* Simplify select() for isotropic and blockdiag factorisations, too

* Delete marginal_nth_derivative because it hasn't been used anywhere
  • Loading branch information
pnkraemer authored Oct 28, 2024
1 parent 2a1b2c1 commit c5260d3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 101 deletions.
59 changes: 21 additions & 38 deletions probdiffeq/impl/_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ def to_derivative(self, i, standard_deviation):


class DenseConditional(ConditionalBackend):
def __init__(self, ode_shape, num_derivatives, unravel):
def __init__(self, ode_shape, num_derivatives, unravel, flat_shape):
self.ode_shape = ode_shape
self.num_derivatives = num_derivatives
self.unravel = unravel
self.flat_shape = flat_shape

def apply(self, x, conditional, /):
matrix, noise = conditional
Expand Down Expand Up @@ -178,8 +179,6 @@ def revert(self, rv, conditional, /):
mean, cholesky = rv.mean, rv.cholesky

# QR-decomposition
# (todo: rename revert_conditional_noisefree to
# revert_transformation_cov_sqrt())
r_obs, (r_cor, gain) = cholesky_util.revert_conditional(
R_X_F=(matrix @ cholesky).T, R_X=cholesky.T, R_YX=noise.cholesky.T
)
Expand Down Expand Up @@ -208,8 +207,7 @@ def ibm_transitions(self, *, output_scale):
A = np.kron(a, eye_d)
Q = np.kron(q_sqrtm, eye_d)

ndim = d * (self.num_derivatives + 1)
q0 = np.zeros((ndim,))
q0 = np.zeros(self.flat_shape)
noise = _normal.Normal(q0, Q)

precon_fun = preconditioner_prepare(num_derivatives=self.num_derivatives)
Expand All @@ -230,34 +228,25 @@ def preconditioner_apply(self, cond, p, p_inv, /):
return Conditional(A, noise)

def to_derivative(self, i, standard_deviation):
a0 = functools.partial(self._select, idx_or_slice=i)
x = np.zeros(self.flat_shape)

def select(a):
return self.unravel(a)[i]

linop = functools.jacrev(select)(x)

(d,) = self.ode_shape
bias = np.zeros((d,))
eye = np.eye(d)
noise = _normal.Normal(bias, standard_deviation * eye)

x = np.zeros(((self.num_derivatives + 1) * d,))
linop = _jac_materialize(lambda s, _p: self._autobatch_linop(a0)(s), inputs=x)
return Conditional(linop, noise)

def _select(self, x, /, idx_or_slice):
return self.unravel(x)[idx_or_slice]

@staticmethod
def _autobatch_linop(fun):
def fun_(x):
if np.ndim(x) > 1:
return functools.vmap(fun_, in_axes=1, out_axes=1)(x)
return fun(x)

return fun_


class IsotropicConditional(ConditionalBackend):
def __init__(self, *, ode_shape, num_derivatives):
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
self.ode_shape = ode_shape
self.num_derivatives = num_derivatives
self.unravel_tree = unravel_tree

def apply(self, x, conditional, /):
A, noise = conditional
Expand Down Expand Up @@ -332,22 +321,24 @@ def preconditioner_apply(self, cond, p, p_inv, /):
return Conditional(A_new, noise)

def to_derivative(self, i, standard_deviation):
def A(x):
return x[[i], ...]
def select(a):
return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0]

m = np.zeros((self.num_derivatives + 1,))
linop = functools.jacrev(select)(m)

bias = np.zeros(self.ode_shape)
eye = np.eye(1)
noise = _normal.Normal(bias, standard_deviation * eye)

m = np.zeros((self.num_derivatives + 1,))
linop = _jac_materialize(lambda s, _p: A(s), inputs=m)
return Conditional(linop, noise)


class BlockDiagConditional(ConditionalBackend):
def __init__(self, *, ode_shape, num_derivatives):
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
self.ode_shape = ode_shape
self.num_derivatives = num_derivatives
self.unravel_tree = unravel_tree

def apply(self, x, conditional, /):
if np.ndim(x) == 1:
Expand Down Expand Up @@ -434,15 +425,11 @@ def preconditioner_apply(self, cond, p, p_inv, /):
return Conditional(A_new, noise)

def to_derivative(self, i, standard_deviation):
def A(x):
return x[[i], ...]

@functools.vmap
def lo(y):
return _jac_materialize(lambda s, _p: A(s), inputs=y)
def select(a):
return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0]

x = np.zeros((*self.ode_shape, self.num_derivatives + 1))
linop = lo(x)
linop = functools.vmap(functools.jacrev(select))(x)

bias = np.zeros((*self.ode_shape, 1))
eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...]
Expand Down Expand Up @@ -494,7 +481,3 @@ def _batch_gram(k, /):

def _binom(n, k):
return np.factorial(n) / (np.factorial(n - k) * np.factorial(k))


def _jac_materialize(func, /, *, inputs, params=None):
return functools.jacrev(lambda v: func(v, params))(inputs)
59 changes: 0 additions & 59 deletions probdiffeq/impl/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ def rescale_cholesky(self, rv, factor, /):
def qoi(self, rv):
raise NotImplementedError

@abc.abstractmethod
def marginal_nth_derivative(self, rv):
raise NotImplementedError

@abc.abstractmethod
def qoi_from_sample(self, sample, /):
raise NotImplementedError
Expand Down Expand Up @@ -105,37 +101,11 @@ def to_multivariate_normal(self, rv):
def qoi(self, rv):
return self.qoi_from_sample(rv.mean)

def marginal_nth_derivative(self, rv, i):
if rv.mean.ndim > 1:
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
rv, i
)

m = self._select(rv.mean, i)
c = functools.vmap(self._select, in_axes=(1, None), out_axes=1)(rv.cholesky, i)
c = cholesky_util.triu_via_qr(c.T)
return _normal.Normal(m, c.T)

def qoi_from_sample(self, sample, /):
if np.ndim(sample) > 1:
return functools.vmap(self.qoi_from_sample)(sample)
return self.unravel(sample)

def _select(self, x, /, idx_or_slice):
x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F")
if isinstance(idx_or_slice, int) and idx_or_slice > x_reshaped.shape[0]:
raise ValueError
return x_reshaped[idx_or_slice]

@staticmethod
def _autobatch_linop(fun):
def fun_(x):
if np.ndim(x) > 1:
return functools.vmap(fun_, in_axes=1, out_axes=1)(x)
return fun(x)

return fun_

def update_mean(self, mean, x, /, num):
nominator = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x)
denominator = np.sqrt(num + 1)
Expand Down Expand Up @@ -198,19 +168,6 @@ def to_multivariate_normal(self, rv):
mean = rv.mean.reshape((-1,), order="F")
return (mean, cov)

def marginal_nth_derivative(self, rv, i):
if np.ndim(rv.mean) > 2:
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
rv, i
)

if i > np.shape(rv.mean)[0]:
raise ValueError

mean = rv.mean[i, :]
cholesky = cholesky_util.triu_via_qr(rv.cholesky[[i], :].T).T
return _normal.Normal(mean, cholesky)

def qoi(self, rv):
return self.qoi_from_sample(rv.mean)

Expand Down Expand Up @@ -287,22 +244,6 @@ def qoi_from_sample(self, sample, /):
return functools.vmap(self.qoi_from_sample)(sample)
return self.unravel(sample)

def marginal_nth_derivative(self, rv, i):
if np.ndim(rv.mean) > 2:
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
rv, i
)

if i > np.shape(rv.mean)[0]:
raise ValueError

mean = rv.mean[:, i]
cholesky = functools.vmap(cholesky_util.triu_via_qr)(
(rv.cholesky[:, i, :])[..., None]
)
cholesky = np.transpose(cholesky, axes=(0, 2, 1))
return _normal.Normal(mean, cholesky)

def update_mean(self, mean, x, /, num):
if np.ndim(mean) > 0:
assert np.shape(mean) == np.shape(x)
Expand Down
11 changes: 7 additions & 4 deletions probdiffeq/impl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def choose(which: str, /, *, tcoeffs_like) -> FactImpl:

def _select_dense(*, tcoeffs_like) -> FactImpl:
ode_shape = tcoeffs_like[0].shape
_, unravel = tree_util.ravel_pytree(tcoeffs_like)
flat, unravel = tree_util.ravel_pytree(tcoeffs_like)

num_derivatives = len(tcoeffs_like) - 1

Expand All @@ -49,7 +49,10 @@ def _select_dense(*, tcoeffs_like) -> FactImpl:
linearise = _linearise.DenseLinearisation(ode_shape=ode_shape, unravel=unravel)
stats = _stats.DenseStats(ode_shape=ode_shape, unravel=unravel)
conditional = _conditional.DenseConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel=unravel
ode_shape=ode_shape,
num_derivatives=num_derivatives,
unravel=unravel,
flat_shape=flat.shape,
)
transform = _conditional.DenseTransform()
return FactImpl(
Expand All @@ -76,7 +79,7 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl:
stats = _stats.IsotropicStats(ode_shape=ode_shape, unravel=unravel)
linearise = _linearise.IsotropicLinearisation()
conditional = _conditional.IsotropicConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
)
transform = _conditional.IsotropicTransform()
return FactImpl(
Expand All @@ -103,7 +106,7 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl:
stats = _stats.BlockDiagStats(ode_shape=ode_shape, unravel=unravel)
linearise = _linearise.BlockDiagLinearisation()
conditional = _conditional.BlockDiagConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
)
transform = _conditional.BlockDiagTransform(ode_shape=ode_shape)
return FactImpl(
Expand Down

0 comments on commit c5260d3

Please sign in to comment.