From 13c41fd60d3467338ae1b13426f803902fbf44f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 13 Jun 2024 11:28:41 +0200 Subject: [PATCH] Regroup the content of the taylor-subpackage in a taylor-module (#760) * Collect all taylor-code in probdiffeq.taylor * Delete the (now outdated) taylor package * Rename taylor.* functions to express *what* they do, not *how* they do it * Wrap a diffeqzoo.backend.select to avoid unnecessary ci-crashes --- docs/api_docs/taylor.md | 1 + docs/api_docs/taylor/affine.md | 1 - docs/api_docs/taylor/autodiff.md | 1 - docs/api_docs/taylor/estim.md | 1 - docs/benchmarks/hires/run_hires.py | 5 +- .../lotkavolterra/run_lotkavolterra.py | 5 +- docs/benchmarks/pleiades/run_pleiades.py | 5 +- .../run_taylor_fitzhughnagumo.py | 14 +-- .../benchmarks/taylor_node/run_taylor_node.py | 19 ++-- .../taylor_pleiades/run_taylor_pleiades.py | 12 +- docs/benchmarks/vanderpol/run_vanderpol.py | 5 +- docs/dev_docs/changelog.md | 2 +- .../use_equinox_bounded_while_loop.py | 5 +- .../physics_enhanced_regression_2.py | 7 +- docs/examples_quickstart/easy_example.py | 5 +- .../conditioning-on-zero-residual.py | 5 +- .../posterior_uncertainties.py | 5 +- .../second_order_problems.py | 7 +- .../transitioning_from_other_packages.md | 4 +- docs/getting_started/troubleshooting.md | 6 +- mkdocs.yml | 5 +- probdiffeq/{taylor/autodiff.py => taylor.py} | 104 ++++++++++++++++-- probdiffeq/taylor/__init__.py | 1 - probdiffeq/taylor/affine.py | 21 ---- probdiffeq/taylor/estim.py | 69 ------------ .../test_fixed_grid_vs_save_every_step.py | 5 +- .../test_save_at_vs_save_every_step.py | 5 +- tests/test_ivpsolve/test_save_every_step.py | 5 +- tests/test_ivpsolve/test_solution_object.py | 5 +- ...test_terminal_values_vs_save_every_step.py | 5 +- ...test_mle_calibration_vs_calibrationfree.py | 11 +- .../test_warnings_for_wrong_strategies.py | 7 +- .../test_log_marginal_likelihood.py | 7 +- ...log_marginal_likelihood_terminal_values.py | 5 +- .../test_stats/test_offgrid_marginals.py | 5 +- tests/test_solvers/test_stats/test_sample.py | 5 +- .../test_rmse_of_correction.py | 5 +- .../test_filter_vs_smoother_rmse.py | 5 +- ...test_smoother_vs_fixedpoint_equivalence.py | 5 +- .../data/generate_reference_solutions.py | 8 +- tests/test_taylor/test_affine_recursion.py | 6 +- tests/test_taylor/test_exact_first_order.py | 12 +- tests/test_taylor/test_exact_higher_order.py | 8 +- tests/test_taylor/test_inexact_first_order.py | 6 +- 44 files changed, 200 insertions(+), 235 deletions(-) create mode 100644 docs/api_docs/taylor.md delete mode 100644 docs/api_docs/taylor/affine.md delete mode 100644 docs/api_docs/taylor/autodiff.md delete mode 100644 docs/api_docs/taylor/estim.md rename probdiffeq/{taylor/autodiff.py => taylor.py} (66%) delete mode 100644 probdiffeq/taylor/__init__.py delete mode 100644 probdiffeq/taylor/affine.py delete mode 100644 probdiffeq/taylor/estim.py diff --git a/docs/api_docs/taylor.md b/docs/api_docs/taylor.md new file mode 100644 index 00000000..a6e36b15 --- /dev/null +++ b/docs/api_docs/taylor.md @@ -0,0 +1 @@ +::: probdiffeq.taylor diff --git a/docs/api_docs/taylor/affine.md b/docs/api_docs/taylor/affine.md deleted file mode 100644 index 79068433..00000000 --- a/docs/api_docs/taylor/affine.md +++ /dev/null @@ -1 +0,0 @@ -::: probdiffeq.taylor.affine diff --git a/docs/api_docs/taylor/autodiff.md b/docs/api_docs/taylor/autodiff.md deleted file mode 100644 index 20e65aff..00000000 --- a/docs/api_docs/taylor/autodiff.md +++ /dev/null @@ -1 +0,0 @@ -::: probdiffeq.taylor.autodiff diff --git a/docs/api_docs/taylor/estim.md b/docs/api_docs/taylor/estim.md deleted file mode 100644 index 2eafa6da..00000000 --- a/docs/api_docs/taylor/estim.md +++ /dev/null @@ -1 +0,0 @@ -::: probdiffeq.taylor.estim diff --git a/docs/benchmarks/hires/run_hires.py b/docs/benchmarks/hires/run_hires.py index 6ddfb9b8..09fd2e02 100644 --- a/docs/benchmarks/hires/run_hires.py +++ b/docs/benchmarks/hires/run_hires.py @@ -17,9 +17,8 @@ import scipy.integrate import tqdm -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -102,7 +101,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives) + tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives) init = solver.initial_condition(tcoeffs, output_scale=1.0) # Solve diff --git a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py index 87250164..426dc322 100644 --- a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py +++ b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py @@ -18,9 +18,8 @@ import scipy.integrate import tqdm -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -90,7 +89,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives) + tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives) output_scale = 1.0 * jnp.ones((2,)) if implementation == "blockdiag" else 1.0 init = solver.initial_condition(tcoeffs, output_scale=output_scale) diff --git a/docs/benchmarks/pleiades/run_pleiades.py b/docs/benchmarks/pleiades/run_pleiades.py index ec29219b..b0f7421c 100644 --- a/docs/benchmarks/pleiades/run_pleiades.py +++ b/docs/benchmarks/pleiades/run_pleiades.py @@ -18,9 +18,8 @@ import scipy.integrate import tqdm -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -114,7 +113,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1) + tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1) init = solver.initial_condition(tcoeffs, output_scale=1.0) # Solve diff --git a/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py b/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py index 65c3122c..c5bc3caa 100644 --- a/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py +++ b/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py @@ -14,8 +14,8 @@ import jax import jax.numpy as jnp +from probdiffeq import taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -63,7 +63,7 @@ def taylor_mode_scan() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num) + tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -75,7 +75,7 @@ def taylor_mode_unroll() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_unroll(vf_auto, (u0,), num=num) + tcoeffs = taylor.odejet_unroll(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -87,19 +87,19 @@ def taylor_mode_doubling() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num_doublings=num) + tcoeffs = taylor.odejet_doubling_unroll(vf_auto, (u0,), num_doublings=num) return jax.block_until_ready(tcoeffs) return estimate -def forward_mode_recursive() -> Callable: +def odejet_via_jvp() -> Callable: """Forward-mode estimation.""" vf_auto, (u0,) = _fitzhugh_nagumo() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num) + tcoeffs = taylor.odejet_via_jvp(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -153,7 +153,7 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: if __name__ == "__main__": set_jax_config() algorithms = { - r"Forward-mode": forward_mode_recursive(), + r"Forward-mode": odejet_via_jvp(), r"Taylor-mode (scan)": taylor_mode_scan(), r"Taylor-mode (unroll)": taylor_mode_unroll(), r"Taylor-mode (doubling)": taylor_mode_doubling(), diff --git a/docs/benchmarks/taylor_node/run_taylor_node.py b/docs/benchmarks/taylor_node/run_taylor_node.py index 9c79cf6e..e8233743 100644 --- a/docs/benchmarks/taylor_node/run_taylor_node.py +++ b/docs/benchmarks/taylor_node/run_taylor_node.py @@ -15,8 +15,8 @@ import jax.numpy as jnp from diffeqzoo import backend +from probdiffeq import taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -64,7 +64,7 @@ def taylor_mode_scan() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num) + tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -76,7 +76,7 @@ def taylor_mode_unroll() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_unroll(vf_auto, (u0,), num=num) + tcoeffs = taylor.odejet_unroll(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -88,19 +88,19 @@ def taylor_mode_doubling() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num_doublings=num) + tcoeffs = taylor.odejet_doubling_unroll(vf_auto, (u0,), num_doublings=num) return jax.block_until_ready(tcoeffs) return estimate -def forward_mode_recursive() -> Callable: +def odejet_via_jvp() -> Callable: """Forward-mode estimation.""" vf_auto, (u0,) = _node() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num) + tcoeffs = taylor.odejet_via_jvp(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -165,9 +165,12 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: if __name__ == "__main__": set_jax_config() - backend.select("jax") + + if not backend.has_been_selected: + backend.select("jax") + algorithms = { - r"Forward-mode": forward_mode_recursive(), + r"Forward-mode": odejet_via_jvp(), r"Taylor-mode (scan)": taylor_mode_scan(), r"Taylor-mode (unroll)": taylor_mode_unroll(), r"Taylor-mode (doubling)": taylor_mode_doubling(), diff --git a/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py b/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py index 84ae9e56..82bbaad6 100644 --- a/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py +++ b/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py @@ -14,8 +14,8 @@ import jax import jax.numpy as jnp +from probdiffeq import taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -63,7 +63,7 @@ def taylor_mode_scan() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num) + tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -75,19 +75,19 @@ def taylor_mode_unroll() -> Callable: @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode_unroll(vf_auto, (u0, du0), num=num) + tcoeffs = taylor.odejet_unroll(vf_auto, (u0, du0), num=num) return jax.block_until_ready(tcoeffs) return estimate -def forward_mode_recursive() -> Callable: +def odejet_via_jvp() -> Callable: """Forward-mode estimation.""" vf_auto, (u0, du0) = _pleiades() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0, du0), num=num) + tcoeffs = taylor.odejet_via_jvp(vf_auto, (u0, du0), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -161,7 +161,7 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: if __name__ == "__main__": set_jax_config() algorithms = { - r"Forward-mode": forward_mode_recursive(), + r"Forward-mode": odejet_via_jvp(), r"Taylor-mode (scan)": taylor_mode_scan(), r"Taylor-mode (unroll)": taylor_mode_unroll(), } diff --git a/docs/benchmarks/vanderpol/run_vanderpol.py b/docs/benchmarks/vanderpol/run_vanderpol.py index 3a0879fd..84c8b03c 100644 --- a/docs/benchmarks/vanderpol/run_vanderpol.py +++ b/docs/benchmarks/vanderpol/run_vanderpol.py @@ -17,9 +17,8 @@ import scipy.integrate import tqdm -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -93,7 +92,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1) + tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1) init = solver.initial_condition(tcoeffs, output_scale=1.0) # Solve diff --git a/docs/dev_docs/changelog.md b/docs/dev_docs/changelog.md index 9711b440..735a2ef8 100644 --- a/docs/dev_docs/changelog.md +++ b/docs/dev_docs/changelog.md @@ -9,7 +9,7 @@ **Breaking changes:** * What was formerly `taylor_mode()`, is now `taylor_mode_scan()` and stands in contrast to the new `taylor_mode_unroll()`. -* What was formerly `forward_mode()`, is now `forward_mode_recursive()`. +* What was formerly `forward_mode()`, is now `odejet_via_jvp()`. * The entire `taylor` subpackage moved to top-level. Instead of `from probdiffeq.solvers.taylor import ...`, use `from probdiffeq.taylor import ...`. diff --git a/docs/examples_misc/use_equinox_bounded_while_loop.py b/docs/examples_misc/use_equinox_bounded_while_loop.py index 63c87cb6..f7cd5ecf 100644 --- a/docs/examples_misc/use_equinox_bounded_while_loop.py +++ b/docs/examples_misc/use_equinox_bounded_while_loop.py @@ -24,10 +24,9 @@ import jax import jax.numpy as jnp -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import control_flow from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff jax.config.update("jax_platform_name", "cpu") impl.select("dense", ode_shape=(1,)) @@ -70,7 +69,7 @@ def vf(y, *, t): # noqa: ARG001 solver = ivpsolvers.solver(strategy) adaptive_solver = ivpsolve.adaptive(solver) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=1) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1) init = solver.initial_condition(tcoeffs, 1.0) def simulate(init_val): diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py index 8f44406e..a42874a1 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py @@ -132,9 +132,8 @@ import matplotlib.pyplot as plt from diffeqzoo import backend, ivps -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import notebook # + @@ -194,7 +193,7 @@ def solve_fixed(theta, *, ts): strategy = ivpsolvers.strategy_filter(ibm, ts0) solver = ivpsolvers.solver(strategy) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2) output_scale = 10.0 init = solver.initial_condition(tcoeffs, output_scale) @@ -212,7 +211,7 @@ def solve_adaptive(theta, *, save_at): solver = ivpsolvers.solver(strategy) adaptive_solver = ivpsolve.adaptive(solver) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2) output_scale = 10.0 init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_adaptive_save_at( diff --git a/docs/examples_quickstart/easy_example.py b/docs/examples_quickstart/easy_example.py index a5b4ea6b..7c1bc69e 100644 --- a/docs/examples_quickstart/easy_example.py +++ b/docs/examples_quickstart/easy_example.py @@ -22,9 +22,8 @@ import jax import jax.numpy as jnp -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff jax.config.update("jax_platform_name", "cpu") @@ -117,7 +116,7 @@ def vf(y, *, t): # noqa: ARG001 # # Use the following functions: -tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) +tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) output_scale = 1.0 # or any other value with the same shape init = solver.initial_condition(tcoeffs, output_scale) diff --git a/docs/examples_solver_config/conditioning-on-zero-residual.py b/docs/examples_solver_config/conditioning-on-zero-residual.py index a6ac7e20..e1293e99 100644 --- a/docs/examples_solver_config/conditioning-on-zero-residual.py +++ b/docs/examples_solver_config/conditioning-on-zero-residual.py @@ -26,9 +26,8 @@ import matplotlib.pyplot as plt from diffeqzoo import backend -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import notebook # - @@ -71,7 +70,7 @@ def vector_field(y, t): # noqa: ARG001 markov_seq_prior = stats.MarkovSeq(init_raw, transitions) -tcoeffs = autodiff.taylor_mode_scan( +tcoeffs = taylor.odejet_padded_scan( lambda y: vector_field(y, t=t0), (u0,), num=NUM_DERIVATIVES ) init_tcoeffs = impl.ssm_util.normal_from_tcoeffs( diff --git a/docs/examples_solver_config/posterior_uncertainties.py b/docs/examples_solver_config/posterior_uncertainties.py index f3d92698..a4f93eef 100644 --- a/docs/examples_solver_config/posterior_uncertainties.py +++ b/docs/examples_solver_config/posterior_uncertainties.py @@ -22,9 +22,8 @@ import matplotlib.pyplot as plt from diffeqzoo import backend, ivps -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import notebook # - @@ -72,7 +71,7 @@ def vf(*ys, t): # noqa: ARG001 # + dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), (u0,)) -tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) +tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) init = solver.initial_condition(tcoeffs, output_scale=1.0) sol = ivpsolve.solve_adaptive_save_at( vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver diff --git a/docs/examples_solver_config/second_order_problems.py b/docs/examples_solver_config/second_order_problems.py index e4cc1a7f..f26558a9 100644 --- a/docs/examples_solver_config/second_order_problems.py +++ b/docs/examples_solver_config/second_order_problems.py @@ -22,9 +22,8 @@ import matplotlib.pyplot as plt from diffeqzoo import backend, ivps -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import notebook # - @@ -58,7 +57,7 @@ def vf_1(y, t): # noqa: ARG001 adaptive_solver_1st = ivpsolve.adaptive(solver_1st, atol=1e-5, rtol=1e-5) -tcoeffs = autodiff.taylor_mode_scan(lambda y: vf_1(y, t=t0), (u0,), num=4) +tcoeffs = taylor.odejet_padded_scan(lambda y: vf_1(y, t=t0), (u0,), num=4) init = solver_1st.initial_condition(tcoeffs, output_scale=1.0) # - @@ -92,7 +91,7 @@ def vf_2(y, dy, t): # noqa: ARG001 adaptive_solver_2nd = ivpsolve.adaptive(solver_2nd, atol=1e-5, rtol=1e-5) -tcoeffs = autodiff.taylor_mode_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3) +tcoeffs = taylor.odejet_padded_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3) init = solver_2nd.initial_condition(tcoeffs, output_scale=1.0) # - diff --git a/docs/getting_started/transitioning_from_other_packages.md b/docs/getting_started/transitioning_from_other_packages.md index ec983867..a35df8a1 100644 --- a/docs/getting_started/transitioning_from_other_packages.md +++ b/docs/getting_started/transitioning_from_other_packages.md @@ -39,8 +39,8 @@ ProbDiffEq can reproduce most of the implementations in Tornadox: | `solver.solve()` | `solve_adaptive_save_every_step()` | Try `solve_adaptive_save_at()` instead. | | `solver.simulate_final_state()` | `simulate_terminal_values()` | ProbDiffEq compiles the whole loop; it will be much faster. | | `solver.solution_generator()` | Work in progress. | | -| `init.TaylorMode()` | `taylor.autodiff.taylor_mode` | Consider `taylor.autodiff.forward_mode_recursive()` for low numbers of derivatives and `taylor.autodiff.taylor_mode_doubling()` for (absurdly) high numbers of derivatives | -| `init.RungeKutta()` | `taylor.estim.make_runge_kutta_starter()` | | +| `init.TaylorMode()` | `taylor.taylor.taylor_mode` | Consider `taylor.taylor.odejet_via_jvp()` for low numbers of derivatives and `taylor.taylor.odejet_doubling_unroll()` for (absurdly) high numbers of derivatives | +| `init.RungeKutta()` | `taylor.taylor.runge_kutta_starter()` | | Recently, the development of Tornadox has not been very active, so its API is relatively stable. diff --git a/docs/getting_started/troubleshooting.md b/docs/getting_started/troubleshooting.md index be6e86e8..84220f8f 100644 --- a/docs/getting_started/troubleshooting.md +++ b/docs/getting_started/troubleshooting.md @@ -5,14 +5,14 @@ If a solution routine takes surprisingly long to compile but then executes quickly, it may be due to the choice of Taylor-coefficient computation. Some functions in `probdiffeq.taylor` unroll a (small) loop. -To avoid this, use `probdiffeq.taylor.autodiff.taylor_mode_scan()` +To avoid this, use `probdiffeq.taylor.taylor.odejet_padded_scan()` (which is implemented with a scan). If the problem persists, reduce the number of derivatives (if that is appropriate for your integration problem) or switch to a different Taylor-coefficient routine. -For example, use a Runge-Kutta starter `probdiffeq.taylor.estim.make_runge_kutta_starter()`. +For example, use a Runge-Kutta starter `probdiffeq.taylor.taylor.runge_kutta_starter()`. For $\nu < 5$, switching to Runge-Kutta starters should preserve performance of the solvers. -High-order methods, e.g. $\nu = 9$ seem to rely on `taylor_fn=taylor.taylor_mode_fn`. +High-order methods, e.g. $\nu = 9$ seem to rely on `taylor_fn=taylor.odejet_fn`. ## Other problems diff --git a/mkdocs.yml b/mkdocs.yml index 03234fea..fe7f0746 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -94,10 +94,7 @@ nav: - ivpsolve: api_docs/ivpsolve.md - ivpsolvers: api_docs/ivpsolvers.md - stats: api_docs/stats.md - - taylor: - - affine: api_docs/taylor/affine.md - - autodiff: api_docs/taylor/autodiff.md - - estim: api_docs/taylor/estim.md + - taylor: api_docs/taylor.md - impl: api_docs/impl.md - DEVELOPER DOCUMENTATION: - Changelog: dev_docs/changelog.md diff --git a/probdiffeq/taylor/autodiff.py b/probdiffeq/taylor.py similarity index 66% rename from probdiffeq/taylor/autodiff.py rename to probdiffeq/taylor.py index 769a844b..bbc0f387 100644 --- a/probdiffeq/taylor/autodiff.py +++ b/probdiffeq/taylor.py @@ -1,16 +1,81 @@ r"""Taylor-expand the solution of an initial value problem (IVP).""" -from probdiffeq.backend import control_flow, functools, itertools, tree_util +from probdiffeq.backend import control_flow, functools, itertools, ode, tree_util from probdiffeq.backend import numpy as np from probdiffeq.backend.typing import Array, Callable +from probdiffeq.impl import impl +from probdiffeq.util import filter_util -def taylor_mode_scan(vf: Callable, inits: tuple[Array, ...], /, num: int): +def runge_kutta_starter(dt, *, atol=1e-12, rtol=1e-10): + """Create an estimator that uses a Runge-Kutta starter.""" + # If the accuracy of the initialisation is bad, play around with dt. + return functools.partial(_runge_kutta_starter, dt0=dt, atol=atol, rtol=rtol) + + +# atol and rtol must be static bc. of jax.odeint... +@functools.partial( + functools.jit, static_argnums=[0], static_argnames=["num", "atol", "rtol"] +) +def _runge_kutta_starter(vf, initial_values, /, num: int, t, dt0, atol, rtol): + # TODO [inaccuracy]: the initial-value uncertainty is discarded + # TODO [feature]: allow implementations other than IsoIBM? + # TODO [feature]: higher-order ODEs + + # Assertions and early exits + + if len(initial_values) > 1: + msg = "Higher-order ODEs are not supported at the moment." + raise ValueError(msg) + + if num == 0: + return initial_values + + if num == 1: + return *initial_values, vf(*initial_values, t) + + # Generate data + + # TODO: allow flexible "solve" method? + k = num + 1 # important: k > num + ts = np.linspace(t, t + dt0 * (k - 1), num=k, endpoint=True) + ys = ode.odeint_and_save_at(vf, initial_values, save_at=ts, atol=atol, rtol=rtol) + + # Initial condition + estimator = filter_util.fixedpointsmoother_precon() + rv_t0 = impl.ssm_util.standard_normal(num + 1, 1.0) + conditional_t0 = impl.ssm_util.identity_conditional(num + 1) + init = (rv_t0, conditional_t0) + + # Discretised prior + discretise = impl.ssm_util.ibm_transitions(num, 1.0) + ibm_transitions = functools.vmap(discretise)(np.diff(ts)) + + # Generate an observation-model for the QOI + # (1e-7 observation noise for nuggets and for reusing existing code) + model_fun = functools.vmap( + impl.hidden_model.conditional_to_derivative, in_axes=(None, 0) + ) + models = model_fun(0, 1e-7 * np.ones_like(ts)) + + # Run the preconditioned fixedpoint smoother + (corrected, conditional), _ = filter_util.estimate_fwd( + ys, + init=init, + prior_transitions=ibm_transitions, + observation_model=models, + estimator=estimator, + ) + initial = impl.conditional.marginalise(corrected, conditional) + return tuple(impl.stats.mean(initial)) + + +def odejet_padded_scan(vf: Callable, inits: tuple[Array, ...], /, num: int): """Taylor-expand the solution of an IVP with Taylor-mode differentiation. - Other than `taylor_mode_unroll()`, this function implements the loop via a scan, + Other than `odejet_unroll()`, this function implements the loop via a scan, which comes at the price of padding the loop variable with zeros as appropriate. - It is expected to compile more quickly than `taylor_mode_unroll()`, but may + It is expected to compile more quickly than `odejet_unroll()`, but may execute more slowly. The differences should be small. @@ -27,7 +92,7 @@ def body(tcoeffs, _): # Pad the Taylor coefficients in zeros, call jet, and return the solution. # This works, because the $i$th output coefficient of jet() # is independent of the $i+j$th input coefficient - # (see also the explanation in taylor_mode_doubling) + # (see also the explanation in odejet_doubling_unroll) series = _subsets(tcoeffs[1:], num_arguments) # for high-order ODEs p, s_new = functools.jet(vf, primals=inits, series=series) @@ -54,12 +119,12 @@ def body(tcoeffs, _): return taylor_coeffs -def taylor_mode_unroll(vf: Callable, inits: tuple[Array, ...], /, num: int): +def odejet_unroll(vf: Callable, inits: tuple[Array, ...], /, num: int): """Taylor-expand the solution of an IVP with Taylor-mode differentiation. - Other than `taylor_mode_scan()`, this function does not depend on zero-padding + Other than `odejet_padded_scan()`, this function does not depend on zero-padding the coefficients at the price of unrolling a loop of length `num-1`. - It is expected to compile more slowly than `taylor_mode_scan()`, + It is expected to compile more slowly than `odejet_padded_scan()`, but execute more quickly. The differences should be small. @@ -109,7 +174,7 @@ def mask(i): return [x[mask(k) : mask(k + 1 - n)] for k in range(n)] -def forward_mode_recursive(vf: Callable, inits: tuple[Array, ...], /, num: int): +def odejet_via_jvp(vf: Callable, inits: tuple[Array, ...], /, num: int): """Taylor-expand the solution of an IVP with recursive forward-mode differentiation. !!! warning "Compilation time" @@ -138,7 +203,9 @@ def df(*args): return tree_util.Partial(df) -def taylor_mode_doubling(vf: Callable, inits: tuple[Array, ...], /, num_doublings: int): +def odejet_doubling_unroll( + vf: Callable, inits: tuple[Array, ...], /, num_doublings: int +): """Combine Taylor-mode differentiation and Newton's doubling. !!! warning "Warning: highly EXPERIMENTAL feature!" @@ -214,3 +281,20 @@ def _unnormalise(primals, *series): """Normalised Taylor series to un-normalised Taylor series.""" series_new = [s * np.factorial(i + 1) for i, s in enumerate(series)] return primals, *series_new + + +def odejet_affine(vf: Callable, initial_values: tuple[Array, ...], /, num: int): + """Evaluate the Taylor series of an affine differential equation. + + !!! warning "Compilation time" + JIT-compiling this function unrolls a loop of length `num`. + + """ + if num == 0: + return initial_values + + fx, jvp_fn = functools.linearize(vf, *initial_values) + + tmp = fx + fx_evaluations = [tmp := jvp_fn(tmp) for _ in range(num - 1)] + return [*initial_values, fx, *fx_evaluations] diff --git a/probdiffeq/taylor/__init__.py b/probdiffeq/taylor/__init__.py deleted file mode 100644 index b45b4ed5..00000000 --- a/probdiffeq/taylor/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Taylor-series estimation.""" diff --git a/probdiffeq/taylor/affine.py b/probdiffeq/taylor/affine.py deleted file mode 100644 index 6524c45e..00000000 --- a/probdiffeq/taylor/affine.py +++ /dev/null @@ -1,21 +0,0 @@ -r"""Taylor-expand the solution of an initial value problem (IVP).""" - -from probdiffeq.backend import functools -from probdiffeq.backend.typing import Array, Callable - - -def affine_recursion(vf: Callable, initial_values: tuple[Array, ...], /, num: int): - """Evaluate the Taylor series of an affine differential equation. - - !!! warning "Compilation time" - JIT-compiling this function unrolls a loop of length `num`. - - """ - if num == 0: - return initial_values - - fx, jvp_fn = functools.linearize(vf, *initial_values) - - tmp = fx - fx_evaluations = [tmp := jvp_fn(tmp) for _ in range(num - 1)] - return [*initial_values, fx, *fx_evaluations] diff --git a/probdiffeq/taylor/estim.py b/probdiffeq/taylor/estim.py deleted file mode 100644 index 6e273bdb..00000000 --- a/probdiffeq/taylor/estim.py +++ /dev/null @@ -1,69 +0,0 @@ -r"""Taylor-expand the solution of an initial value problem (IVP).""" - -from probdiffeq.backend import functools, ode -from probdiffeq.backend import numpy as np -from probdiffeq.impl import impl -from probdiffeq.util import filter_util - - -def make_runge_kutta_starter(dt, *, atol=1e-12, rtol=1e-10): - """Create an estimator that uses a Runge-Kutta starter.""" - # If the accuracy of the initialisation is bad, play around with dt. - return functools.partial(_runge_kutta_starter, dt0=dt, atol=atol, rtol=rtol) - - -# atol and rtol must be static bc. of jax.odeint... -@functools.partial( - functools.jit, static_argnums=[0], static_argnames=["num", "atol", "rtol"] -) -def _runge_kutta_starter(vf, initial_values, /, num: int, t, dt0, atol, rtol): - # TODO [inaccuracy]: the initial-value uncertainty is discarded - # TODO [feature]: allow implementations other than IsoIBM? - # TODO [feature]: higher-order ODEs - - # Assertions and early exits - - if len(initial_values) > 1: - msg = "Higher-order ODEs are not supported at the moment." - raise ValueError(msg) - - if num == 0: - return initial_values - - if num == 1: - return *initial_values, vf(*initial_values, t) - - # Generate data - - # TODO: allow flexible "solve" method? - k = num + 1 # important: k > num - ts = np.linspace(t, t + dt0 * (k - 1), num=k, endpoint=True) - ys = ode.odeint_and_save_at(vf, initial_values, save_at=ts, atol=atol, rtol=rtol) - - # Initial condition - estimator = filter_util.fixedpointsmoother_precon() - rv_t0 = impl.ssm_util.standard_normal(num + 1, 1.0) - conditional_t0 = impl.ssm_util.identity_conditional(num + 1) - init = (rv_t0, conditional_t0) - - # Discretised prior - discretise = impl.ssm_util.ibm_transitions(num, 1.0) - ibm_transitions = functools.vmap(discretise)(np.diff(ts)) - - # Generate an observation-model for the QOI - # (1e-7 observation noise for nuggets and for reusing existing code) - model_fun = functools.vmap( - impl.hidden_model.conditional_to_derivative, in_axes=(None, 0) - ) - models = model_fun(0, 1e-7 * np.ones_like(ts)) - - # Run the preconditioned fixedpoint smoother - (corrected, conditional), _ = filter_util.estimate_fwd( - ys, - init=init, - prior_transitions=ibm_transitions, - observation_model=models, - estimator=estimator, - ) - initial = impl.conditional.marginalise(corrected, conditional) - return tuple(impl.stats.mean(initial)) diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index 1832ca20..0a980130 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -1,10 +1,9 @@ """Compare solve_fixed_grid to solve_adaptive_save_every_step.""" -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -18,7 +17,7 @@ def test_fixed_grid_result_matches_adaptive_grid_result(): control = ivpsolve.control_integral_clipped() # Any clipped controller will do. adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, control=control) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) args = (vf, init) diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index fc84a421..bb3ef4ab 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -1,10 +1,9 @@ """Assert that solve_adaptive_save_at is consistent with solve_with_python_loop().""" -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.backend import functools, testing, tree_util from probdiffeq.backend import numpy as np from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -20,7 +19,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(): solver = ivpsolvers.solver(strategy) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) problem_args = (vf, init) diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index dfcc5343..f2d14150 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -1,10 +1,9 @@ """Assert that solve_with_python_loop is accurate.""" -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import ode, testing from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -22,7 +21,7 @@ def fixture_python_loop_solution(): vf, u0, t0=t0, atol=1e-2, rtol=1e-2, error_contraction_rate=5 ) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) diff --git a/tests/test_ivpsolve/test_solution_object.py b/tests/test_ivpsolve/test_solution_object.py index 80b2eb5d..dfb8daf1 100644 --- a/tests/test_ivpsolve/test_solution_object.py +++ b/tests/test_ivpsolve/test_solution_object.py @@ -1,10 +1,9 @@ """Tests for interaction with the solution object.""" -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import functools, testing from probdiffeq.backend import numpy as np from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -20,7 +19,7 @@ def fixture_approximate_solution(): adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=1) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=1) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_adaptive_save_every_step( diff --git a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py index b9499199..149bf7bd 100644 --- a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py @@ -1,10 +1,9 @@ """Compare simulate_terminal_values to solve_adaptive_save_every_step.""" -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import testing, tree_util from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -19,7 +18,7 @@ def fixture_problem_args_kwargs(): solver = ivpsolvers.solver_mle(strategy) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) diff --git a/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py b/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py index 424637f0..95bf23d8 100644 --- a/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py +++ b/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py @@ -5,18 +5,17 @@ After applying stats.calibrate(), the posterior is different. """ -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @testing.case() def case_solve_fixed_grid(): vf, u0, (t0, t1) = setup.ode() - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = np.ones_like(impl.prototypes.output_scale()) kwargs = {"grid": np.linspace(t0, t1, endpoint=True, num=5)} @@ -31,7 +30,7 @@ def solver_to_solution(solver): def case_solve_adaptive_save_at(): vf, u0, (t0, t1) = setup.ode() dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), u0) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = np.ones_like(impl.prototypes.output_scale()) kwargs = {"save_at": np.linspace(t0, t1, endpoint=True, num=5), "dt0": dt0} @@ -49,7 +48,7 @@ def solver_to_solution(solver): def case_solve_adaptive_save_every_step(): vf, u0, (t0, t1) = setup.ode() dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), u0) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = np.ones_like(impl.prototypes.output_scale()) kwargs = {"t0": t0, "t1": t1, "dt0": dt0} @@ -67,7 +66,7 @@ def solver_to_solution(solver): def case_simulate_terminal_values(): vf, u0, (t0, t1) = setup.ode() dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), u0) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = np.ones_like(impl.prototypes.output_scale()) kwargs = {"t0": t0, "t1": t1, "dt0": dt0} diff --git a/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py b/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py index acd325e3..664dba73 100644 --- a/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py +++ b/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py @@ -1,10 +1,9 @@ """Some strategies don't work with all solution routines.""" -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -18,7 +17,7 @@ def test_warning_for_fixedpoint_in_save_every_step_mode(): adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) with testing.warns(): @@ -37,7 +36,7 @@ def test_warning_for_smoother_in_save_at_mode(): adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) with testing.warns(): diff --git a/tests/test_solvers/test_stats/test_log_marginal_likelihood.py b/tests/test_solvers/test_stats/test_log_marginal_likelihood.py index 82298d7b..02dfb810 100644 --- a/tests/test_solvers/test_stats/test_log_marginal_likelihood.py +++ b/tests/test_solvers/test_stats/test_log_marginal_likelihood.py @@ -1,10 +1,9 @@ """Tests for log-marginal-likelihood functionality.""" -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import testing, tree_util from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -19,7 +18,7 @@ def fixture_sol(): adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) save_at = np.linspace(t0, t1, endpoint=True, num=4) @@ -91,7 +90,7 @@ def test_raises_error_for_filter(): solver = ivpsolvers.solver(strategy) grid = np.linspace(t0, t1, num=3) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) sol = ivpsolve.solve_fixed_grid(vf, init, grid=grid, solver=solver) diff --git a/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py b/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py index 7e5da594..4b183911 100644 --- a/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py +++ b/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py @@ -1,10 +1,9 @@ """Tests for marginal log likelihood functionality (terminal values).""" -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -34,7 +33,7 @@ def fixture_sol(strategy_func): solver = ivpsolvers.solver(strategy) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_adaptive_terminal_values( diff --git a/tests/test_solvers/test_stats/test_offgrid_marginals.py b/tests/test_solvers/test_stats/test_offgrid_marginals.py index 0ecee319..3662b3f4 100644 --- a/tests/test_solvers/test_stats/test_offgrid_marginals.py +++ b/tests/test_solvers/test_stats/test_offgrid_marginals.py @@ -1,9 +1,8 @@ """Tests for IVP solvers.""" -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.backend import numpy as np from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -40,7 +39,7 @@ def test_smoother_marginals_close_to_both_boundaries(): solver = ivpsolvers.solver(strategy) output_scale = np.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) init = solver.initial_condition(tcoeffs, output_scale) grid = np.linspace(t0, t1, endpoint=True, num=5) sol = ivpsolve.solve_fixed_grid(vf, init, grid=grid, solver=solver) diff --git a/tests/test_solvers/test_stats/test_sample.py b/tests/test_solvers/test_stats/test_sample.py index 06b04766..9a1cc0ed 100644 --- a/tests/test_solvers/test_stats/test_sample.py +++ b/tests/test_solvers/test_stats/test_sample.py @@ -1,10 +1,9 @@ """Tests for sampling behaviour.""" -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import random, testing, tree_util from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -19,7 +18,7 @@ def fixture_approximation(): adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_adaptive_save_every_step( vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1 diff --git a/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py b/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py index 080d982e..73976b7f 100644 --- a/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py +++ b/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py @@ -1,10 +1,9 @@ """Assert that every recipe yields a decent ODE approximation.""" -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import ode, testing from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -70,7 +69,7 @@ def fixture_solution(correction_impl): adaptive_kwargs = {"adaptive_solver": adaptive_solver, "dt0": 0.1} - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_adaptive_terminal_values( diff --git a/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py b/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py index 2bebbdfe..e7acb10b 100644 --- a/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py +++ b/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py @@ -1,10 +1,9 @@ """The RMSE of the smoother should be (slightly) lower than the RMSE of the filter.""" -from probdiffeq import ivpsolve, ivpsolvers +from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import linalg, ode, testing from probdiffeq.backend import numpy as np from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -14,7 +13,7 @@ def fixture_solver_setup(): output_scale = np.ones_like(impl.prototypes.output_scale()) grid = np.linspace(t0, t1, endpoint=True, num=12) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) return {"vf": vf, "tcoeffs": tcoeffs, "grid": grid, "output_scale": output_scale} diff --git a/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py b/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py index e42c0d1c..09c3b458 100644 --- a/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py +++ b/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py @@ -3,11 +3,10 @@ That is, when called with correct adaptive- and checkpoint-setups. """ -from probdiffeq import ivpsolve, ivpsolvers, stats +from probdiffeq import ivpsolve, ivpsolvers, stats, taylor from probdiffeq.backend import functools, testing, tree_util from probdiffeq.backend import numpy as np from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff from tests.setup import setup @@ -17,7 +16,7 @@ def fixture_solver_setup(): output_scale = np.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) return { "vf": vf, "tcoeffs": tcoeffs, diff --git a/tests/test_taylor/data/generate_reference_solutions.py b/tests/test_taylor/data/generate_reference_solutions.py index fe01a3c3..ea3646d9 100644 --- a/tests/test_taylor/data/generate_reference_solutions.py +++ b/tests/test_taylor/data/generate_reference_solutions.py @@ -1,8 +1,8 @@ """Precompute and save reference solutions. Accelerate testing.""" +from probdiffeq import taylor from probdiffeq.backend import config, ode from probdiffeq.backend import numpy as np -from probdiffeq.taylor import autodiff def set_environment(): @@ -20,15 +20,13 @@ def set_environment(): def three_body_first(num_derivatives_max=6): vf, (u0,), (t0, _) = ode.ivp_three_body_1st() - return autodiff.taylor_mode_unroll( - lambda y: vf(y, t=t0), (u0,), num=num_derivatives_max - ) + return taylor.odejet_unroll(lambda y: vf(y, t=t0), (u0,), num=num_derivatives_max) def van_der_pol_second(num_derivatives_max=6): vf, (u0, du0), (t0, _) = ode.ivp_van_der_pol_2nd() - return autodiff.taylor_mode_unroll( + return taylor.odejet_unroll( lambda *ys: vf(*ys, t=t0), (u0, du0), num=num_derivatives_max ) diff --git a/tests/test_taylor/test_affine_recursion.py b/tests/test_taylor/test_affine_recursion.py index af0d66c8..26bfb3c8 100644 --- a/tests/test_taylor/test_affine_recursion.py +++ b/tests/test_taylor/test_affine_recursion.py @@ -1,8 +1,8 @@ """Tests for the affine recursion.""" +from probdiffeq import taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import testing -from probdiffeq.taylor import affine, autodiff @testing.parametrize("num", [1, 2, 4]) @@ -10,7 +10,7 @@ def test_affine_recursion(num, num_derivatives_max=5): """The approximation should coincide with the reference.""" f, init, solution = _affine_problem(num_derivatives_max) - derivatives = affine.affine_recursion(f, init, num=num) + derivatives = taylor.odejet_affine(f, init, num=num) # check shape assert len(derivatives) == len(init) + num @@ -29,5 +29,5 @@ def vf(x, /): init = (np.arange(9.0, 11.0),) - solution = autodiff.taylor_mode_scan(vf, init, num=n) + solution = taylor.odejet_padded_scan(vf, init, num=n) return vf, init, solution diff --git a/tests/test_taylor/test_exact_first_order.py b/tests/test_taylor/test_exact_first_order.py index 4170258e..7eaafc27 100644 --- a/tests/test_taylor/test_exact_first_order.py +++ b/tests/test_taylor/test_exact_first_order.py @@ -1,23 +1,23 @@ """Test the exactness of differentiation-based routines on first-order problems.""" +from probdiffeq import taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import ode, testing -from probdiffeq.taylor import autodiff @testing.case() -def case_forward_mode_recursive(): - return autodiff.forward_mode_recursive +def case_odejet_via_jvp(): + return taylor.odejet_via_jvp @testing.case() def case_taylor_mode_scan(): - return autodiff.taylor_mode_scan + return taylor.odejet_padded_scan @testing.case() def case_taylor_mode_unroll(): - return autodiff.taylor_mode_unroll + return taylor.odejet_unroll @testing.fixture(name="pb_with_solution") @@ -44,7 +44,7 @@ def test_approximation_identical_to_reference_doubling(pb_with_solution, num_dou """Separately test the doubling-function, because its API is different.""" (f, init), solution = pb_with_solution - derivatives = autodiff.taylor_mode_doubling(f, init, num_doublings=num_doublings) + derivatives = taylor.odejet_doubling_unroll(f, init, num_doublings=num_doublings) assert len(derivatives) == np.sum(2 ** np.arange(0, num_doublings + 1)) for dy, dy_ref in zip(derivatives, solution): assert np.allclose(dy, dy_ref) diff --git a/tests/test_taylor/test_exact_higher_order.py b/tests/test_taylor/test_exact_higher_order.py index 7b82eec5..668b1a32 100644 --- a/tests/test_taylor/test_exact_higher_order.py +++ b/tests/test_taylor/test_exact_higher_order.py @@ -1,18 +1,18 @@ """Test the exactness of differentiation-based routines on first-order problems.""" +from probdiffeq import taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import ode, testing -from probdiffeq.taylor import autodiff @testing.case() -def case_forward_mode_recursive(): - return autodiff.forward_mode_recursive +def case_odejet_via_jvp(): + return taylor.odejet_via_jvp @testing.case() def case_taylor_mode_scan(): - return autodiff.taylor_mode_scan + return taylor.odejet_padded_scan @testing.fixture(name="pb_with_solution") diff --git a/tests/test_taylor/test_inexact_first_order.py b/tests/test_taylor/test_inexact_first_order.py index 1771c033..e23f0e7e 100644 --- a/tests/test_taylor/test_inexact_first_order.py +++ b/tests/test_taylor/test_inexact_first_order.py @@ -1,23 +1,23 @@ """Tests for inexact approximations for first-order problems.""" +from probdiffeq import taylor from probdiffeq.backend import numpy as np from probdiffeq.backend import ode, testing from probdiffeq.impl import impl -from probdiffeq.taylor import autodiff, estim @testing.case() def case_runge_kutta_starter(): if impl.impl_name != "isotropic": testing.skip(reason="Runge-Kutta starters currently require isotropic SSMs.") - return estim.make_runge_kutta_starter(dt=0.01) + return taylor.runge_kutta_starter(dt=0.01) @testing.fixture(name="pb_with_solution") def fixture_pb_with_solution(): vf, (u0,), (t0, _) = ode.ivp_lotka_volterra() - solution = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=3) + solution = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3) return (vf, (u0,), t0), solution