From 0a9c553b06743f4da9598bd361538e00757c5e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Fri, 15 Dec 2023 11:12:54 +0100 Subject: [PATCH] Implement backend.functools.linearize (#712) * Implement backend.functools.linearize * Variable number of arguments in linearize --- probdiffeq/backend/functools.py | 4 ++++ probdiffeq/impl/dense/_linearise.py | 3 +-- probdiffeq/taylor/affine.py | 5 ++--- probdiffeq/taylor/autodiff.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/probdiffeq/backend/functools.py b/probdiffeq/backend/functools.py index be983a01..5205bf4c 100644 --- a/probdiffeq/backend/functools.py +++ b/probdiffeq/backend/functools.py @@ -19,3 +19,7 @@ def jit(func, /, static_argnums=None, static_argnames=None): def jet(func, /, primals, series): return jax.experimental.jet.jet(func, primals=primals, series=series) + + +def linearize(func, *args): + return jax.linearize(func, *args) diff --git a/probdiffeq/impl/dense/_linearise.py b/probdiffeq/impl/dense/_linearise.py index 5b7dc102..d48f74f9 100644 --- a/probdiffeq/impl/dense/_linearise.py +++ b/probdiffeq/impl/dense/_linearise.py @@ -1,6 +1,5 @@ """Linearisation.""" -import jax from probdiffeq.backend import functools from probdiffeq.backend import numpy as np @@ -121,7 +120,7 @@ def ts0(fn, m): def ts1(fn, m): - b, jvp = jax.linearize(fn, m) + b, jvp = functools.linearize(fn, m) return jvp, b - jvp(m) diff --git a/probdiffeq/taylor/affine.py b/probdiffeq/taylor/affine.py index 7a7c5a12..631cc92f 100644 --- a/probdiffeq/taylor/affine.py +++ b/probdiffeq/taylor/affine.py @@ -2,8 +2,7 @@ from typing import Callable -import jax - +from probdiffeq.backend import functools from probdiffeq.backend.typing import Array @@ -17,7 +16,7 @@ def affine_recursion(vf: Callable, initial_values: tuple[Array, ...], /, num: in if num == 0: return initial_values - fx, jvp_fn = jax.linearize(vf, *initial_values) + fx, jvp_fn = functools.linearize(vf, *initial_values) tmp = fx fx_evaluations = [tmp := jvp_fn(tmp) for _ in range(num - 1)] diff --git a/probdiffeq/taylor/autodiff.py b/probdiffeq/taylor/autodiff.py index c2635cf6..da20d20e 100644 --- a/probdiffeq/taylor/autodiff.py +++ b/probdiffeq/taylor/autodiff.py @@ -179,7 +179,7 @@ def jet_embedded(*c, degree): degrees = list(itertools.accumulate(map(lambda s: 2**s, range(num_doublings)))) for deg in degrees: jet_embedded_deg = tree_util.Partial(jet_embedded, degree=deg) - fx, jvp = jax.linearize(jet_embedded_deg, *taylor_coefficients) + fx, jvp = functools.linearize(jet_embedded_deg, *taylor_coefficients) # Compute the next set of coefficients. # TODO: can we jax.fori_loop() this loop?