From 31f027ab3160078433a26a74807bf3e1bdeee499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Fri, 25 Oct 2024 11:43:22 +0200 Subject: [PATCH] Replace probdiffeq.util.linop_util with jax.Arrays to improve hackability (#787) * All linop_util.LinOp's are jax.Arrays now * Delete linop_utils.py because this single line of code is more readable if repeated in each module * Rename parametrise_linop to jac_materialize because it better represents what the function does --- probdiffeq/impl/_conditional.py | 108 +++++++------------------------- probdiffeq/impl/_linearise.py | 30 ++++++--- probdiffeq/util/linop_util.py | 34 ---------- 3 files changed, 43 insertions(+), 129 deletions(-) delete mode 100644 probdiffeq/util/linop_util.py diff --git a/probdiffeq/impl/_conditional.py b/probdiffeq/impl/_conditional.py index 08158756..3c31328b 100644 --- a/probdiffeq/impl/_conditional.py +++ b/probdiffeq/impl/_conditional.py @@ -4,7 +4,7 @@ from probdiffeq.backend import numpy as np from probdiffeq.backend.typing import Any, Array from probdiffeq.impl import _normal -from probdiffeq.util import cholesky_util, linop_util +from probdiffeq.util import cholesky_util class TransformBackend(abc.ABC): @@ -80,7 +80,8 @@ def marginalise(self, rv, transformation, /): A_cholesky = A @ cholesky cholesky = functools.vmap(cholesky_util.triu_via_qr)(_transpose(A_cholesky)) - mean = A @ mean + b + + mean = functools.vmap(lambda x, y, z: x @ y + z)(A, mean, b) return _normal.Normal(mean, cholesky) def revert(self, rv, transformation, /): @@ -94,7 +95,7 @@ def revert(self, rv, transformation, /): cholesky_cor = _transpose(r_cor) # Gather terms and return - mean_observed = (A @ rv.mean) + bias + mean_observed = functools.vmap(lambda x, y, z: x @ y + z)(A, rv.mean, bias) m_cor = rv.mean - (gain * (mean_observed[..., None]))[..., 0] corrected = _normal.Normal(m_cor, cholesky_cor) observed = _normal.Normal(mean_observed, cholesky_obs) @@ -145,82 +146,6 @@ def to_derivative(self, i, standard_deviation): raise NotImplementedError -class ScalarConditional(ConditionalBackend): - def marginalise(self, rv, conditional, /): - matrix, noise = conditional - - mean = matrix @ rv.mean + noise.mean - R_stack = ((matrix @ rv.cholesky).T, noise.cholesky.T) - cholesky_T = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack) - return _normal.Normal(mean, cholesky_T.T) - - def revert(self, rv, conditional, /): - matrix, noise = conditional - - r_ext, (r_bw_p, g_bw_p) = cholesky_util.revert_conditional( - R_X_F=(matrix @ rv.cholesky).T, R_X=rv.cholesky.T, R_YX=noise.cholesky.T - ) - m_ext = matrix @ rv.mean + noise.mean - m_cond = rv.mean - g_bw_p @ m_ext - - marginal = _normal.Normal(m_ext, r_ext.T) - noise = _normal.Normal(m_cond, r_bw_p.T) - return marginal, Conditional(g_bw_p, noise) - - def apply(self, x, conditional, /): - matrix, noise = conditional - matrix = np.squeeze(matrix) - return _normal.Normal(linalg.vector_dot(matrix, x) + noise.mean, noise.cholesky) - - def merge(self, previous, incoming, /): - A, b = previous - C, d = incoming - - g = A @ C - xi = A @ d.mean + b.mean - R_stack = ((A @ d.cholesky).T, b.cholesky.T) - Xi = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T - - noise = _normal.Normal(xi, Xi) - return Conditional(g, noise) - - def identity(self, ndim, /) -> Conditional: - transition = np.eye(ndim) - mean = np.zeros((ndim,)) - cov_sqrtm = np.zeros((ndim, ndim)) - noise = _normal.Normal(mean, cov_sqrtm) - return Conditional(transition, noise) - - def ibm_transitions(self, num_derivatives, 
output_scale=1.0): - a, q_sqrtm = system_matrices_1d(num_derivatives, output_scale) - q0 = np.zeros((num_derivatives + 1,)) - noise = _normal.Normal(q0, q_sqrtm) - - precon_fun = preconditioner_prepare(num_derivatives=num_derivatives) - - def discretise(dt): - p, p_inv = precon_fun(dt) - return Conditional(a, noise), (p, p_inv) - - return discretise - - def preconditioner_apply(self, cond, p, p_inv, /): - A, noise = cond - A = p[:, None] * A * p_inv[None, :] - noise = _normal.Normal(p * noise.mean, p[:, None] * noise.cholesky) - return Conditional(A, noise) - - def to_derivative(self, i, standard_deviation): - def A(x): - return x[[i], ...] - - bias = np.zeros(()) - eye = np.eye(1) - noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop(lambda s, _p: A(s)) - return Conditional(linop, noise) - - class DenseConditional(ConditionalBackend): def __init__(self, ode_shape, num_derivatives, unravel): self.ode_shape = ode_shape @@ -311,9 +236,9 @@ def to_derivative(self, i, standard_deviation): bias = np.zeros((d,)) eye = np.eye(d) noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop( - lambda s, _p: self._autobatch_linop(a0)(s) - ) + + x = np.zeros(((self.num_derivatives + 1) * d,)) + linop = _jac_materialize(lambda s, _p: self._autobatch_linop(a0)(s), inputs=x) return Conditional(linop, noise) def _select(self, x, /, idx_or_slice): @@ -413,7 +338,9 @@ def A(x): bias = np.zeros(self.ode_shape) eye = np.eye(1) noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop(lambda s, _p: A(s)) + + m = np.zeros((self.num_derivatives + 1,)) + linop = _jac_materialize(lambda s, _p: A(s), inputs=m) return Conditional(linop, noise) @@ -508,12 +435,19 @@ def preconditioner_apply(self, cond, p, p_inv, /): def to_derivative(self, i, standard_deviation): def A(x): - return x[:, [i], ...] + return x[[i], ...] + + @functools.vmap + def lo(y): + return _jac_materialize(lambda s, _p: A(s), inputs=y) + + x = np.zeros((*self.ode_shape, self.num_derivatives + 1)) + linop = lo(x) bias = np.zeros((*self.ode_shape, 1)) eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...] 
noise = _normal.Normal(bias, standard_deviation * eye) - linop = linop_util.parametrised_linop(lambda s, _p: A(s)) + return Conditional(linop, noise) @@ -560,3 +494,7 @@ def _batch_gram(k, /): def _binom(n, k): return np.factorial(n) / (np.factorial(n - k) * np.factorial(k)) + + +def _jac_materialize(func, /, *, inputs, params=None): + return functools.jacrev(lambda v: func(v, params))(inputs) diff --git a/probdiffeq/impl/_linearise.py b/probdiffeq/impl/_linearise.py index 9b5e38e0..03399d29 100644 --- a/probdiffeq/impl/_linearise.py +++ b/probdiffeq/impl/_linearise.py @@ -1,7 +1,7 @@ from probdiffeq.backend import abc, functools from probdiffeq.backend import numpy as np from probdiffeq.impl import _normal -from probdiffeq.util import cholesky_util, linop_util +from probdiffeq.util import cholesky_util class LinearisationBackend(abc.ABC): @@ -37,9 +37,8 @@ def linearise_fun_wrapped(fun, mean): raise ValueError(msg) fx = self.ts0(fun, a0(mean)) - - linop = linop_util.parametrised_linop( - lambda v, _p: self._autobatch_linop(a1)(v) + linop = _jac_materialize( + lambda v, _p: self._autobatch_linop(a1)(v), inputs=mean ) return linop, -fx @@ -62,7 +61,7 @@ def A(x): x0 = a0(x) return x1 - jvp(x0) - linop = linop_util.parametrised_linop(lambda v, _p: A(v)) + linop = _jac_materialize(lambda v, _p: A(v), inputs=mean) return linop, -fx return new @@ -91,7 +90,7 @@ def new(fun, rv, /): def A(x): return a1(x) - J @ a0(x) - linop = linop_util.parametrised_linop(lambda v, _p: A(v)) + linop = _jac_materialize(lambda v, _p: A(v), inputs=rv.mean) mean, cov_lower = noise.mean, noise.cholesky bias = _normal.Normal(-mean, cov_lower) @@ -121,7 +120,7 @@ def new(fun, rv, /): noise = linearise_fun(fun, linearisation_pt) mean, cov_lower = noise.mean, noise.cholesky bias = _normal.Normal(-mean, cov_lower) - linop = linop_util.parametrised_linop(lambda v, _p: a1(v)) + linop = _jac_materialize(lambda v, _p: a1(v), inputs=rv.mean) return linop, bias return new @@ -202,7 +201,9 @@ def ode_taylor_1st(self, ode_order): def ode_taylor_0th(self, ode_order): def linearise_fun_wrapped(fun, mean): fx = self.ts0(fun, mean[:ode_order, ...]) - linop = linop_util.parametrised_linop(lambda s, _p: s[[ode_order], ...]) + linop = _jac_materialize( + lambda s, _p: s[[ode_order], ...], inputs=mean[:, 0] + ) return linop, -fx return linearise_fun_wrapped @@ -225,9 +226,14 @@ def linearise_fun_wrapped(fun, mean): fx = self.ts0(fun, m0.T) def a1(s): - return s[:, [ode_order], ...] + return s[[ode_order], ...] 
+ + @functools.vmap + def lo(s): + return _jac_materialize(lambda v, _p: a1(v), inputs=s) - return linop_util.parametrised_linop(lambda v, _p: a1(v)), -fx[:, None] + linop = lo(mean) + return linop, -fx[:, None] return linearise_fun_wrapped @@ -243,3 +249,7 @@ def ode_statistical_1st(self, cubature_fun): @staticmethod def ts0(fn, m): return fn(m) + + +def _jac_materialize(func, /, *, inputs, params=None): + return functools.jacrev(lambda v: func(v, params))(inputs) diff --git a/probdiffeq/util/linop_util.py b/probdiffeq/util/linop_util.py deleted file mode 100644 index 630bb55c..00000000 --- a/probdiffeq/util/linop_util.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Matrix-free API.""" - -from probdiffeq.backend import containers, tree_util -from probdiffeq.backend.typing import Any, Callable - - -def parametrised_linop(func, /, params=None): - return CallableLinOp(func=func, params=params) - - -@containers.dataclass -class CallableLinOp: - """Matrix-free linear operator.""" - - func: Callable - params: Any - - def __matmul__(self, other): - return self.func(other, self.params) - - -def _linop_flatten(linop): - children = (linop.params,) - aux = (linop.func,) - return children, aux - - -def _linop_unflatten(aux, children): - (func,) = aux - (params,) = children - return parametrised_linop(func, params=params) - - -tree_util.register_pytree_node(CallableLinOp, _linop_flatten, _linop_unflatten)
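
For reference, a minimal sketch of the jac-materialisation idea used throughout this patch, written against plain jax rather than probdiffeq's backend wrappers (assuming those wrappers forward to jax.jacrev, jax.vmap, and jax.numpy). For a linear map, the Jacobian evaluated at any input of the right shape is the matrix itself, so the materialised jax.Array reproduces the matmul behaviour of the deleted matrix-free parametrised_linop. The helper name, the selector operator, and the shapes below are illustrative only.

    import jax
    import jax.numpy as jnp


    def jac_materialize(func, /, *, inputs, params=None):
        # Same one-liner that the patch now repeats in each module:
        # materialise a (linear) matrix-free operator as a dense matrix.
        return jax.jacrev(lambda v: func(v, params))(inputs)


    # Example operator: select the i-th derivative of a state vector.
    num_derivatives, i = 3, 1

    def select(s, _p):
        return s[[i], ...]

    # Materialise at a dummy input of the correct shape.
    x = jnp.zeros((num_derivatives + 1,))
    linop = jac_materialize(select, inputs=x)  # shape: (1, num_derivatives + 1)

    # The dense matrix behaves under "@" exactly like the matrix-free call.
    v = jnp.arange(num_derivatives + 1.0)
    assert jnp.allclose(linop @ v, select(v, None))

    # Batched case (cf. the BlockDiagConditional hunk): vmap materialises one
    # small matrix per ODE dimension.
    ode_shape = (2,)
    xs = jnp.zeros((*ode_shape, num_derivatives + 1))
    linops = jax.vmap(lambda y: jac_materialize(select, inputs=y))(xs)
    # linops.shape == (2, 1, num_derivatives + 1)

Materialising through jacrev rather than building each matrix by hand keeps every call site a single line, which is why the patch repeats the helper per module instead of keeping the custom CallableLinOp pytree.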