Skip to content

Commit

Permalink
Implement backend.functools.linearize (#712)
Browse files Browse the repository at this point in the history
* Implement backend.functools.linearize

* Variable number of arguments in linearize
  • Loading branch information
pnkraemer authored Dec 15, 2023
1 parent a3566d2 commit 0a9c553
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
4 changes: 4 additions & 0 deletions probdiffeq/backend/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ def jit(func, /, static_argnums=None, static_argnames=None):

def jet(func, /, primals, series):
return jax.experimental.jet.jet(func, primals=primals, series=series)


def linearize(func, *args):
return jax.linearize(func, *args)
3 changes: 1 addition & 2 deletions probdiffeq/impl/dense/_linearise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Linearisation."""

import jax

from probdiffeq.backend import functools
from probdiffeq.backend import numpy as np
Expand Down Expand Up @@ -121,7 +120,7 @@ def ts0(fn, m):


def ts1(fn, m):
b, jvp = jax.linearize(fn, m)
b, jvp = functools.linearize(fn, m)
return jvp, b - jvp(m)


Expand Down
5 changes: 2 additions & 3 deletions probdiffeq/taylor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from typing import Callable

import jax

from probdiffeq.backend import functools
from probdiffeq.backend.typing import Array


Expand All @@ -17,7 +16,7 @@ def affine_recursion(vf: Callable, initial_values: tuple[Array, ...], /, num: in
if num == 0:
return initial_values

fx, jvp_fn = jax.linearize(vf, *initial_values)
fx, jvp_fn = functools.linearize(vf, *initial_values)

tmp = fx
fx_evaluations = [tmp := jvp_fn(tmp) for _ in range(num - 1)]
Expand Down
2 changes: 1 addition & 1 deletion probdiffeq/taylor/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def jet_embedded(*c, degree):
degrees = list(itertools.accumulate(map(lambda s: 2**s, range(num_doublings))))
for deg in degrees:
jet_embedded_deg = tree_util.Partial(jet_embedded, degree=deg)
fx, jvp = jax.linearize(jet_embedded_deg, *taylor_coefficients)
fx, jvp = functools.linearize(jet_embedded_deg, *taylor_coefficients)

# Compute the next set of coefficients.
# TODO: can we jax.fori_loop() this loop?
Expand Down

0 comments on commit 0a9c553

Please sign in to comment.