Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal: Remove private select-style functions because we have easier ways of accessing derivatives now #799

Merged
merged 3 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 21 additions & 38 deletions probdiffeq/impl/_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ def to_derivative(self, i, standard_deviation):


class DenseConditional(ConditionalBackend):
def __init__(self, ode_shape, num_derivatives, unravel):
def __init__(self, ode_shape, num_derivatives, unravel, flat_shape):
self.ode_shape = ode_shape
self.num_derivatives = num_derivatives
self.unravel = unravel
self.flat_shape = flat_shape

def apply(self, x, conditional, /):
matrix, noise = conditional
Expand Down Expand Up @@ -178,8 +179,6 @@ def revert(self, rv, conditional, /):
mean, cholesky = rv.mean, rv.cholesky

# QR-decomposition
# (todo: rename revert_conditional_noisefree to
# revert_transformation_cov_sqrt())
r_obs, (r_cor, gain) = cholesky_util.revert_conditional(
R_X_F=(matrix @ cholesky).T, R_X=cholesky.T, R_YX=noise.cholesky.T
)
Expand Down Expand Up @@ -208,8 +207,7 @@ def ibm_transitions(self, *, output_scale):
A = np.kron(a, eye_d)
Q = np.kron(q_sqrtm, eye_d)

ndim = d * (self.num_derivatives + 1)
q0 = np.zeros((ndim,))
q0 = np.zeros(self.flat_shape)
noise = _normal.Normal(q0, Q)

precon_fun = preconditioner_prepare(num_derivatives=self.num_derivatives)
Expand All @@ -230,34 +228,25 @@ def preconditioner_apply(self, cond, p, p_inv, /):
return Conditional(A, noise)

def to_derivative(self, i, standard_deviation):
a0 = functools.partial(self._select, idx_or_slice=i)
x = np.zeros(self.flat_shape)

def select(a):
return self.unravel(a)[i]

linop = functools.jacrev(select)(x)

(d,) = self.ode_shape
bias = np.zeros((d,))
eye = np.eye(d)
noise = _normal.Normal(bias, standard_deviation * eye)

x = np.zeros(((self.num_derivatives + 1) * d,))
linop = _jac_materialize(lambda s, _p: self._autobatch_linop(a0)(s), inputs=x)
return Conditional(linop, noise)

def _select(self, x, /, idx_or_slice):
return self.unravel(x)[idx_or_slice]

@staticmethod
def _autobatch_linop(fun):
def fun_(x):
if np.ndim(x) > 1:
return functools.vmap(fun_, in_axes=1, out_axes=1)(x)
return fun(x)

return fun_


class IsotropicConditional(ConditionalBackend):
def __init__(self, *, ode_shape, num_derivatives):
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
self.ode_shape = ode_shape
self.num_derivatives = num_derivatives
self.unravel_tree = unravel_tree

def apply(self, x, conditional, /):
A, noise = conditional
Expand Down Expand Up @@ -332,22 +321,24 @@ def preconditioner_apply(self, cond, p, p_inv, /):
return Conditional(A_new, noise)

def to_derivative(self, i, standard_deviation):
def A(x):
return x[[i], ...]
def select(a):
return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0]

m = np.zeros((self.num_derivatives + 1,))
linop = functools.jacrev(select)(m)

bias = np.zeros(self.ode_shape)
eye = np.eye(1)
noise = _normal.Normal(bias, standard_deviation * eye)

m = np.zeros((self.num_derivatives + 1,))
linop = _jac_materialize(lambda s, _p: A(s), inputs=m)
return Conditional(linop, noise)


class BlockDiagConditional(ConditionalBackend):
def __init__(self, *, ode_shape, num_derivatives):
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
self.ode_shape = ode_shape
self.num_derivatives = num_derivatives
self.unravel_tree = unravel_tree

def apply(self, x, conditional, /):
if np.ndim(x) == 1:
Expand Down Expand Up @@ -434,15 +425,11 @@ def preconditioner_apply(self, cond, p, p_inv, /):
return Conditional(A_new, noise)

def to_derivative(self, i, standard_deviation):
def A(x):
return x[[i], ...]

@functools.vmap
def lo(y):
return _jac_materialize(lambda s, _p: A(s), inputs=y)
def select(a):
return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0]

x = np.zeros((*self.ode_shape, self.num_derivatives + 1))
linop = lo(x)
linop = functools.vmap(functools.jacrev(select))(x)

bias = np.zeros((*self.ode_shape, 1))
eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...]
Expand Down Expand Up @@ -494,7 +481,3 @@ def _batch_gram(k, /):

def _binom(n, k):
return np.factorial(n) / (np.factorial(n - k) * np.factorial(k))


def _jac_materialize(func, /, *, inputs, params=None):
return functools.jacrev(lambda v: func(v, params))(inputs)
59 changes: 0 additions & 59 deletions probdiffeq/impl/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ def rescale_cholesky(self, rv, factor, /):
def qoi(self, rv):
raise NotImplementedError

@abc.abstractmethod
def marginal_nth_derivative(self, rv):
raise NotImplementedError

@abc.abstractmethod
def qoi_from_sample(self, sample, /):
raise NotImplementedError
Expand Down Expand Up @@ -105,37 +101,11 @@ def to_multivariate_normal(self, rv):
def qoi(self, rv):
return self.qoi_from_sample(rv.mean)

def marginal_nth_derivative(self, rv, i):
if rv.mean.ndim > 1:
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
rv, i
)

m = self._select(rv.mean, i)
c = functools.vmap(self._select, in_axes=(1, None), out_axes=1)(rv.cholesky, i)
c = cholesky_util.triu_via_qr(c.T)
return _normal.Normal(m, c.T)

def qoi_from_sample(self, sample, /):
if np.ndim(sample) > 1:
return functools.vmap(self.qoi_from_sample)(sample)
return self.unravel(sample)

def _select(self, x, /, idx_or_slice):
x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F")
if isinstance(idx_or_slice, int) and idx_or_slice > x_reshaped.shape[0]:
raise ValueError
return x_reshaped[idx_or_slice]

@staticmethod
def _autobatch_linop(fun):
def fun_(x):
if np.ndim(x) > 1:
return functools.vmap(fun_, in_axes=1, out_axes=1)(x)
return fun(x)

return fun_

def update_mean(self, mean, x, /, num):
nominator = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x)
denominator = np.sqrt(num + 1)
Expand Down Expand Up @@ -198,19 +168,6 @@ def to_multivariate_normal(self, rv):
mean = rv.mean.reshape((-1,), order="F")
return (mean, cov)

def marginal_nth_derivative(self, rv, i):
if np.ndim(rv.mean) > 2:
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
rv, i
)

if i > np.shape(rv.mean)[0]:
raise ValueError

mean = rv.mean[i, :]
cholesky = cholesky_util.triu_via_qr(rv.cholesky[[i], :].T).T
return _normal.Normal(mean, cholesky)

def qoi(self, rv):
return self.qoi_from_sample(rv.mean)

Expand Down Expand Up @@ -287,22 +244,6 @@ def qoi_from_sample(self, sample, /):
return functools.vmap(self.qoi_from_sample)(sample)
return self.unravel(sample)

def marginal_nth_derivative(self, rv, i):
if np.ndim(rv.mean) > 2:
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
rv, i
)

if i > np.shape(rv.mean)[0]:
raise ValueError

mean = rv.mean[:, i]
cholesky = functools.vmap(cholesky_util.triu_via_qr)(
(rv.cholesky[:, i, :])[..., None]
)
cholesky = np.transpose(cholesky, axes=(0, 2, 1))
return _normal.Normal(mean, cholesky)

def update_mean(self, mean, x, /, num):
if np.ndim(mean) > 0:
assert np.shape(mean) == np.shape(x)
Expand Down
11 changes: 7 additions & 4 deletions probdiffeq/impl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def choose(which: str, /, *, tcoeffs_like) -> FactImpl:

def _select_dense(*, tcoeffs_like) -> FactImpl:
ode_shape = tcoeffs_like[0].shape
_, unravel = tree_util.ravel_pytree(tcoeffs_like)
flat, unravel = tree_util.ravel_pytree(tcoeffs_like)

num_derivatives = len(tcoeffs_like) - 1

Expand All @@ -49,7 +49,10 @@ def _select_dense(*, tcoeffs_like) -> FactImpl:
linearise = _linearise.DenseLinearisation(ode_shape=ode_shape, unravel=unravel)
stats = _stats.DenseStats(ode_shape=ode_shape, unravel=unravel)
conditional = _conditional.DenseConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel=unravel
ode_shape=ode_shape,
num_derivatives=num_derivatives,
unravel=unravel,
flat_shape=flat.shape,
)
transform = _conditional.DenseTransform()
return FactImpl(
Expand All @@ -76,7 +79,7 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl:
stats = _stats.IsotropicStats(ode_shape=ode_shape, unravel=unravel)
linearise = _linearise.IsotropicLinearisation()
conditional = _conditional.IsotropicConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
)
transform = _conditional.IsotropicTransform()
return FactImpl(
Expand All @@ -103,7 +106,7 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl:
stats = _stats.BlockDiagStats(ode_shape=ode_shape, unravel=unravel)
linearise = _linearise.BlockDiagLinearisation()
conditional = _conditional.BlockDiagConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
)
transform = _conditional.BlockDiagTransform(ode_shape=ode_shape)
return FactImpl(
Expand Down
Loading