Commit 31f027a
Replace probdiffeq.util.linop_util with jax.Arrays to improve hackability (#787)

* All linop_util.LinOps are now jax.Arrays

* Delete linop_util.py because its single line of code is more readable when repeated in each module

* Rename parametrised_linop to _jac_materialize because the new name better represents what the function does (see the sketch below)
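The gist, sketched outside the commit (illustrative code only; the names jac_materialize, select_first_row, and all shapes are made up here): the Jacobian of a linear map is exactly its matrix, so jax.jacrev evaluated at any correctly shaped input materializes the operator as a plain jax.Array.

    import jax
    import jax.numpy as jnp

    def jac_materialize(func, /, *, inputs, params=None):
        # For a linear `func`, the Jacobian is its matrix; `inputs` only
        # fixes the shape that jacrev differentiates through.
        return jax.jacrev(lambda v: func(v, params))(inputs)

    def select_first_row(s, _p):
        # A linear "operator" of the kind previously wrapped as a LinOp.
        return s[[0], ...]

    v0 = jnp.zeros((3,))
    matrix = jac_materialize(select_first_row, inputs=v0)
    print(matrix)  # [[1. 0. 0.]] -- a (1, 3) jax.Array, composable with plain @

Because the result is an ordinary array, downstream code can use @, transposes, and indexing directly instead of going through a custom LinOp type.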
pnkraemer authored Oct 25, 2024
1 parent 15f8298 commit 31f027a
Showing 3 changed files with 43 additions and 129 deletions.
108 changes: 23 additions & 85 deletions probdiffeq/impl/_conditional.py
@@ -4,7 +4,7 @@
 from probdiffeq.backend import numpy as np
 from probdiffeq.backend.typing import Any, Array
 from probdiffeq.impl import _normal
-from probdiffeq.util import cholesky_util, linop_util
+from probdiffeq.util import cholesky_util
 
 
 class TransformBackend(abc.ABC):
@@ -80,7 +80,8 @@ def marginalise(self, rv, transformation, /):
 
         A_cholesky = A @ cholesky
         cholesky = functools.vmap(cholesky_util.triu_via_qr)(_transpose(A_cholesky))
-        mean = A @ mean + b
+
+        mean = functools.vmap(lambda x, y, z: x @ y + z)(A, mean, b)
         return _normal.Normal(mean, cholesky)
 
     def revert(self, rv, transformation, /):
@@ -94,7 +95,7 @@ def revert(self, rv, transformation, /):
         cholesky_cor = _transpose(r_cor)
 
         # Gather terms and return
-        mean_observed = (A @ rv.mean) + bias
+        mean_observed = functools.vmap(lambda x, y, z: x @ y + z)(A, rv.mean, bias)
         m_cor = rv.mean - (gain * (mean_observed[..., None]))[..., 0]
         corrected = _normal.Normal(m_cor, cholesky_cor)
         observed = _normal.Normal(mean_observed, cholesky_obs)
@@ -145,82 +146,6 @@ def to_derivative(self, i, standard_deviation):
         raise NotImplementedError
 
 
-class ScalarConditional(ConditionalBackend):
-    def marginalise(self, rv, conditional, /):
-        matrix, noise = conditional
-
-        mean = matrix @ rv.mean + noise.mean
-        R_stack = ((matrix @ rv.cholesky).T, noise.cholesky.T)
-        cholesky_T = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack)
-        return _normal.Normal(mean, cholesky_T.T)
-
-    def revert(self, rv, conditional, /):
-        matrix, noise = conditional
-
-        r_ext, (r_bw_p, g_bw_p) = cholesky_util.revert_conditional(
-            R_X_F=(matrix @ rv.cholesky).T, R_X=rv.cholesky.T, R_YX=noise.cholesky.T
-        )
-        m_ext = matrix @ rv.mean + noise.mean
-        m_cond = rv.mean - g_bw_p @ m_ext
-
-        marginal = _normal.Normal(m_ext, r_ext.T)
-        noise = _normal.Normal(m_cond, r_bw_p.T)
-        return marginal, Conditional(g_bw_p, noise)
-
-    def apply(self, x, conditional, /):
-        matrix, noise = conditional
-        matrix = np.squeeze(matrix)
-        return _normal.Normal(linalg.vector_dot(matrix, x) + noise.mean, noise.cholesky)
-
-    def merge(self, previous, incoming, /):
-        A, b = previous
-        C, d = incoming
-
-        g = A @ C
-        xi = A @ d.mean + b.mean
-        R_stack = ((A @ d.cholesky).T, b.cholesky.T)
-        Xi = cholesky_util.sum_of_sqrtm_factors(R_stack=R_stack).T
-
-        noise = _normal.Normal(xi, Xi)
-        return Conditional(g, noise)
-
-    def identity(self, ndim, /) -> Conditional:
-        transition = np.eye(ndim)
-        mean = np.zeros((ndim,))
-        cov_sqrtm = np.zeros((ndim, ndim))
-        noise = _normal.Normal(mean, cov_sqrtm)
-        return Conditional(transition, noise)
-
-    def ibm_transitions(self, num_derivatives, output_scale=1.0):
-        a, q_sqrtm = system_matrices_1d(num_derivatives, output_scale)
-        q0 = np.zeros((num_derivatives + 1,))
-        noise = _normal.Normal(q0, q_sqrtm)
-
-        precon_fun = preconditioner_prepare(num_derivatives=num_derivatives)
-
-        def discretise(dt):
-            p, p_inv = precon_fun(dt)
-            return Conditional(a, noise), (p, p_inv)
-
-        return discretise
-
-    def preconditioner_apply(self, cond, p, p_inv, /):
-        A, noise = cond
-        A = p[:, None] * A * p_inv[None, :]
-        noise = _normal.Normal(p * noise.mean, p[:, None] * noise.cholesky)
-        return Conditional(A, noise)
-
-    def to_derivative(self, i, standard_deviation):
-        def A(x):
-            return x[[i], ...]
-
-        bias = np.zeros(())
-        eye = np.eye(1)
-        noise = _normal.Normal(bias, standard_deviation * eye)
-        linop = linop_util.parametrised_linop(lambda s, _p: A(s))
-        return Conditional(linop, noise)
-
-
 class DenseConditional(ConditionalBackend):
     def __init__(self, ode_shape, num_derivatives, unravel):
         self.ode_shape = ode_shape
@@ -311,9 +236,9 @@ def to_derivative(self, i, standard_deviation):
         bias = np.zeros((d,))
         eye = np.eye(d)
         noise = _normal.Normal(bias, standard_deviation * eye)
-        linop = linop_util.parametrised_linop(
-            lambda s, _p: self._autobatch_linop(a0)(s)
-        )
+
+        x = np.zeros(((self.num_derivatives + 1) * d,))
+        linop = _jac_materialize(lambda s, _p: self._autobatch_linop(a0)(s), inputs=x)
         return Conditional(linop, noise)
 
     def _select(self, x, /, idx_or_slice):
@@ -413,7 +338,9 @@ def A(x):
         bias = np.zeros(self.ode_shape)
         eye = np.eye(1)
         noise = _normal.Normal(bias, standard_deviation * eye)
-        linop = linop_util.parametrised_linop(lambda s, _p: A(s))
+
+        m = np.zeros((self.num_derivatives + 1,))
+        linop = _jac_materialize(lambda s, _p: A(s), inputs=m)
         return Conditional(linop, noise)
 
 
@@ -508,12 +435,19 @@ def preconditioner_apply(self, cond, p, p_inv, /):
 
     def to_derivative(self, i, standard_deviation):
         def A(x):
-            return x[:, [i], ...]
+            return x[[i], ...]
 
+        @functools.vmap
+        def lo(y):
+            return _jac_materialize(lambda s, _p: A(s), inputs=y)
+
+        x = np.zeros((*self.ode_shape, self.num_derivatives + 1))
+        linop = lo(x)
+
         bias = np.zeros((*self.ode_shape, 1))
         eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...]
         noise = _normal.Normal(bias, standard_deviation * eye)
-        linop = linop_util.parametrised_linop(lambda s, _p: A(s))
+
         return Conditional(linop, noise)


@@ -560,3 +494,7 @@ def _batch_gram(k, /):
 
 def _binom(n, k):
     return np.factorial(n) / (np.factorial(n - k) * np.factorial(k))
+
+
+def _jac_materialize(func, /, *, inputs, params=None):
+    return functools.jacrev(lambda v: func(v, params))(inputs)
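An illustrative aside (not part of the diff; the shapes below are invented): the batched call sites above wrap _jac_materialize in functools.vmap, which materializes one matrix per element of the leading batch axis, so the resulting linop is a stacked jax.Array.

    import jax
    import jax.numpy as jnp

    ode_shape, num_derivatives, i = (2,), 3, 1

    def A(x):
        return x[[i], ...]  # linear in x: select the i-th derivative

    @jax.vmap
    def lo(y):
        # jax.jacrev plays the role of functools.jacrev in the module above
        return jax.jacrev(A)(y)

    x = jnp.zeros((*ode_shape, num_derivatives + 1))
    linop = lo(x)
    print(linop.shape)  # (2, 1, 4): one (1, 4) selection matrix per batch entry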
30 changes: 20 additions & 10 deletions probdiffeq/impl/_linearise.py
@@ -1,7 +1,7 @@
 from probdiffeq.backend import abc, functools
 from probdiffeq.backend import numpy as np
 from probdiffeq.impl import _normal
-from probdiffeq.util import cholesky_util, linop_util
+from probdiffeq.util import cholesky_util
 
 
 class LinearisationBackend(abc.ABC):
@@ -37,9 +37,8 @@ def linearise_fun_wrapped(fun, mean):
                 raise ValueError(msg)
 
             fx = self.ts0(fun, a0(mean))
-
-            linop = linop_util.parametrised_linop(
-                lambda v, _p: self._autobatch_linop(a1)(v)
+            linop = _jac_materialize(
+                lambda v, _p: self._autobatch_linop(a1)(v), inputs=mean
             )
             return linop, -fx
 
@@ -62,7 +61,7 @@ def A(x):
                 x0 = a0(x)
                 return x1 - jvp(x0)
 
-            linop = linop_util.parametrised_linop(lambda v, _p: A(v))
+            linop = _jac_materialize(lambda v, _p: A(v), inputs=mean)
             return linop, -fx
 
         return new
@@ -91,7 +90,7 @@ def new(fun, rv, /):
             def A(x):
                 return a1(x) - J @ a0(x)
 
-            linop = linop_util.parametrised_linop(lambda v, _p: A(v))
+            linop = _jac_materialize(lambda v, _p: A(v), inputs=rv.mean)
 
             mean, cov_lower = noise.mean, noise.cholesky
             bias = _normal.Normal(-mean, cov_lower)
@@ -121,7 +120,7 @@ def new(fun, rv, /):
             noise = linearise_fun(fun, linearisation_pt)
             mean, cov_lower = noise.mean, noise.cholesky
             bias = _normal.Normal(-mean, cov_lower)
-            linop = linop_util.parametrised_linop(lambda v, _p: a1(v))
+            linop = _jac_materialize(lambda v, _p: a1(v), inputs=rv.mean)
             return linop, bias
 
         return new
@@ -202,7 +201,9 @@ def ode_taylor_1st(self, ode_order):
     def ode_taylor_0th(self, ode_order):
         def linearise_fun_wrapped(fun, mean):
             fx = self.ts0(fun, mean[:ode_order, ...])
-            linop = linop_util.parametrised_linop(lambda s, _p: s[[ode_order], ...])
+            linop = _jac_materialize(
+                lambda s, _p: s[[ode_order], ...], inputs=mean[:, 0]
+            )
             return linop, -fx
 
         return linearise_fun_wrapped
@@ -225,9 +226,14 @@ def linearise_fun_wrapped(fun, mean):
             fx = self.ts0(fun, m0.T)
 
             def a1(s):
-                return s[:, [ode_order], ...]
+                return s[[ode_order], ...]
 
-            return linop_util.parametrised_linop(lambda v, _p: a1(v)), -fx[:, None]
+            @functools.vmap
+            def lo(s):
+                return _jac_materialize(lambda v, _p: a1(v), inputs=s)
+
+            linop = lo(mean)
+            return linop, -fx[:, None]
 
         return linearise_fun_wrapped

@@ -243,3 +249,7 @@ def ode_statistical_1st(self, cubature_fun):
     @staticmethod
     def ts0(fn, m):
         return fn(m)
+
+
+def _jac_materialize(func, /, *, inputs, params=None):
+    return functools.jacrev(lambda v: func(v, params))(inputs)
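As a quick sanity check (again outside the commit; shapes invented): for the TS0 linearisation above, materializing the selection function s[[ode_order], ...] produces a one-hot selection matrix, i.e. the returned linop is literally the matrix that picks out the ode_order-th Taylor coefficient.

    import jax
    import jax.numpy as jnp

    def _jac_materialize(func, /, *, inputs, params=None):
        return jax.jacrev(lambda v: func(v, params))(inputs)

    ode_order = 1
    mean = jnp.zeros((4,))  # Taylor coefficients u, u', u'', u'''
    linop = _jac_materialize(lambda s, _p: s[[ode_order], ...], inputs=mean)
    print(linop)  # [[0. 1. 0. 0.]] -- the (1, 4) matrix that selects u'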
34 changes: 0 additions & 34 deletions probdiffeq/util/linop_util.py

This file was deleted.
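The deleted module is collapsed on the page, so its contents are not shown. Judging only from the removed call sites above, parametrised_linop wrapped a callable into a matmul-capable object, roughly along these lines (a speculative reconstruction, not the actual deleted code):

    import dataclasses
    from typing import Any, Callable

    @dataclasses.dataclass
    class LinOp:
        # Hypothetical wrapper: func(vector, params) applies the operator.
        func: Callable
        params: Any = None

        def __matmul__(self, other):
            return self.func(other, self.params)

    def parametrised_linop(func, /, params=None):
        return LinOp(func, params)

Replacing such a wrapper with a materialized jax.Array removes one custom abstraction, which is the hackability improvement the commit title refers to.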
