Skip to content

Commit

Permalink
Regroup the content of the taylor-subpackage in a taylor-module (#760)
Browse files Browse the repository at this point in the history
* Collect all taylor-code in probdiffeq.taylor

* Delete the (now outdated) taylor package

* Rename taylor.* functions to express *what* they do, not *how* they do it

* Wrap a diffeqzoo.backend.select to avoid unnecessary ci-crashes
  • Loading branch information
pnkraemer authored Jun 13, 2024
1 parent 710b856 commit 13c41fd
Show file tree
Hide file tree
Showing 44 changed files with 200 additions and 235 deletions.
1 change: 1 addition & 0 deletions docs/api_docs/taylor.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: probdiffeq.taylor
1 change: 0 additions & 1 deletion docs/api_docs/taylor/affine.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/api_docs/taylor/autodiff.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/api_docs/taylor/estim.md

This file was deleted.

5 changes: 2 additions & 3 deletions docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
import scipy.integrate
import tqdm

from probdiffeq import ivpsolve, ivpsolvers
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -102,7 +101,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
init = solver.initial_condition(tcoeffs, output_scale=1.0)

# Solve
Expand Down
5 changes: 2 additions & 3 deletions docs/benchmarks/lotkavolterra/run_lotkavolterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import scipy.integrate
import tqdm

from probdiffeq import ivpsolve, ivpsolvers
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -90,7 +89,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
output_scale = 1.0 * jnp.ones((2,)) if implementation == "blockdiag" else 1.0
init = solver.initial_condition(tcoeffs, output_scale=output_scale)

Expand Down
5 changes: 2 additions & 3 deletions docs/benchmarks/pleiades/run_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import scipy.integrate
import tqdm

from probdiffeq import ivpsolve, ivpsolvers
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -114,7 +113,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
init = solver.initial_condition(tcoeffs, output_scale=1.0)

# Solve
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import jax
import jax.numpy as jnp

from probdiffeq import taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -63,7 +63,7 @@ def taylor_mode_scan() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand All @@ -75,7 +75,7 @@ def taylor_mode_unroll() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_unroll(vf_auto, (u0,), num=num)
tcoeffs = taylor.odejet_unroll(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand All @@ -87,19 +87,19 @@ def taylor_mode_doubling() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num_doublings=num)
tcoeffs = taylor.odejet_doubling_unroll(vf_auto, (u0,), num_doublings=num)
return jax.block_until_ready(tcoeffs)

return estimate


def forward_mode_recursive() -> Callable:
def odejet_via_jvp() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0,) = _fitzhugh_nagumo()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num)
tcoeffs = taylor.odejet_via_jvp(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -153,7 +153,7 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
if __name__ == "__main__":
set_jax_config()
algorithms = {
r"Forward-mode": forward_mode_recursive(),
r"Forward-mode": odejet_via_jvp(),
r"Taylor-mode (scan)": taylor_mode_scan(),
r"Taylor-mode (unroll)": taylor_mode_unroll(),
r"Taylor-mode (doubling)": taylor_mode_doubling(),
Expand Down
19 changes: 11 additions & 8 deletions docs/benchmarks/taylor_node/run_taylor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import jax.numpy as jnp
from diffeqzoo import backend

from probdiffeq import taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -64,7 +64,7 @@ def taylor_mode_scan() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand All @@ -76,7 +76,7 @@ def taylor_mode_unroll() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_unroll(vf_auto, (u0,), num=num)
tcoeffs = taylor.odejet_unroll(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand All @@ -88,19 +88,19 @@ def taylor_mode_doubling() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num_doublings=num)
tcoeffs = taylor.odejet_doubling_unroll(vf_auto, (u0,), num_doublings=num)
return jax.block_until_ready(tcoeffs)

return estimate


def forward_mode_recursive() -> Callable:
def odejet_via_jvp() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0,) = _node()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num)
tcoeffs = taylor.odejet_via_jvp(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -165,9 +165,12 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:

if __name__ == "__main__":
set_jax_config()
backend.select("jax")

if not backend.has_been_selected:
backend.select("jax")

algorithms = {
r"Forward-mode": forward_mode_recursive(),
r"Forward-mode": odejet_via_jvp(),
r"Taylor-mode (scan)": taylor_mode_scan(),
r"Taylor-mode (unroll)": taylor_mode_unroll(),
r"Taylor-mode (doubling)": taylor_mode_doubling(),
Expand Down
12 changes: 6 additions & 6 deletions docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import jax
import jax.numpy as jnp

from probdiffeq import taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -63,7 +63,7 @@ def taylor_mode_scan() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand All @@ -75,19 +75,19 @@ def taylor_mode_unroll() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_unroll(vf_auto, (u0, du0), num=num)
tcoeffs = taylor.odejet_unroll(vf_auto, (u0, du0), num=num)
return jax.block_until_ready(tcoeffs)

return estimate


def forward_mode_recursive() -> Callable:
def odejet_via_jvp() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0, du0) = _pleiades()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0, du0), num=num)
tcoeffs = taylor.odejet_via_jvp(vf_auto, (u0, du0), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -161,7 +161,7 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
if __name__ == "__main__":
set_jax_config()
algorithms = {
r"Forward-mode": forward_mode_recursive(),
r"Forward-mode": odejet_via_jvp(),
r"Taylor-mode (scan)": taylor_mode_scan(),
r"Taylor-mode (unroll)": taylor_mode_unroll(),
}
Expand Down
5 changes: 2 additions & 3 deletions docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
import scipy.integrate
import tqdm

from probdiffeq import ivpsolve, ivpsolvers
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -93,7 +92,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
init = solver.initial_condition(tcoeffs, output_scale=1.0)

# Solve
Expand Down
2 changes: 1 addition & 1 deletion docs/dev_docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
**Breaking changes:**

* What was formerly `taylor_mode()`, is now `taylor_mode_scan()` and stands in contrast to the new `taylor_mode_unroll()`.
* What was formerly `forward_mode()`, is now `forward_mode_recursive()`.
* What was formerly `forward_mode()`, is now `odejet_via_jvp()`.
* The entire `taylor` subpackage moved to top-level. Instead of `from probdiffeq.solvers.taylor import ...`, use `from probdiffeq.taylor import ...`.


Expand Down
5 changes: 2 additions & 3 deletions docs/examples_misc/use_equinox_bounded_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@
import jax
import jax.numpy as jnp

from probdiffeq import ivpsolve, ivpsolvers
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.backend import control_flow
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff

jax.config.update("jax_platform_name", "cpu")
impl.select("dense", ode_shape=(1,))
Expand Down Expand Up @@ -70,7 +69,7 @@ def vf(y, *, t): # noqa: ARG001
solver = ivpsolvers.solver(strategy)
adaptive_solver = ivpsolve.adaptive(solver)

tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=1)
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
init = solver.initial_condition(tcoeffs, 1.0)

def simulate(init_val):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,8 @@
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps

from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import notebook

# +
Expand Down Expand Up @@ -194,7 +193,7 @@ def solve_fixed(theta, *, ts):
strategy = ivpsolvers.strategy_filter(ibm, ts0)
solver = ivpsolvers.solver(strategy)

tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2)
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
output_scale = 10.0
init = solver.initial_condition(tcoeffs, output_scale)

Expand All @@ -212,7 +211,7 @@ def solve_adaptive(theta, *, save_at):
solver = ivpsolvers.solver(strategy)
adaptive_solver = ivpsolve.adaptive(solver)

tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2)
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
output_scale = 10.0
init = solver.initial_condition(tcoeffs, output_scale)
return ivpsolve.solve_adaptive_save_at(
Expand Down
5 changes: 2 additions & 3 deletions docs/examples_quickstart/easy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import jax
import jax.numpy as jnp

from probdiffeq import ivpsolve, ivpsolvers
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff

jax.config.update("jax_platform_name", "cpu")

Expand Down Expand Up @@ -117,7 +116,7 @@ def vf(y, *, t): # noqa: ARG001
#
# Use the following functions:

tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4)
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4)
output_scale = 1.0 # or any other value with the same shape
init = solver.initial_condition(tcoeffs, output_scale)

Expand Down
5 changes: 2 additions & 3 deletions docs/examples_solver_config/conditioning-on-zero-residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
import matplotlib.pyplot as plt
from diffeqzoo import backend

from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import notebook

# -
Expand Down Expand Up @@ -71,7 +70,7 @@ def vector_field(y, t): # noqa: ARG001
markov_seq_prior = stats.MarkovSeq(init_raw, transitions)


tcoeffs = autodiff.taylor_mode_scan(
tcoeffs = taylor.odejet_padded_scan(
lambda y: vector_field(y, t=t0), (u0,), num=NUM_DERIVATIVES
)
init_tcoeffs = impl.ssm_util.normal_from_tcoeffs(
Expand Down
5 changes: 2 additions & 3 deletions docs/examples_solver_config/posterior_uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps

from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
from probdiffeq.impl import impl
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import notebook

# -
Expand Down Expand Up @@ -72,7 +71,7 @@ def vf(*ys, t): # noqa: ARG001
# +
dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), (u0,))

tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4)
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4)
init = solver.initial_condition(tcoeffs, output_scale=1.0)
sol = ivpsolve.solve_adaptive_save_at(
vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver
Expand Down
Loading

0 comments on commit 13c41fd

Please sign in to comment.