From 1d9109024149aa7a70208a301a16b59a466ff1ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 5 Oct 2023 13:21:24 +0200 Subject: [PATCH] Rename taylor_mode() to taylor_mode_scan() and move to toplevel (#666) * taylor_mode() is taylor_mode_scan() now * Forward mode is forward_mode_recursive * Updated changelog * Updated linter dependencies and rerun * Doc update * Updated index and removed 3.12 * Moved taylor to toplevel * Update docs * Show annotations in docs * Removed print from tests * Reformat notebook * Fixed tests --- .pre-commit-config.yaml | 11 ++--- README.md | 16 +++---- docs/api_docs/solvers/taylor/affine.md | 1 - docs/api_docs/solvers/taylor/autodiff.md | 1 - docs/api_docs/solvers/taylor/estim.md | 1 - docs/api_docs/taylor/affine.md | 1 + docs/api_docs/taylor/autodiff.md | 1 + docs/api_docs/taylor/estim.md | 1 + docs/benchmarks/hires/plot.md | 2 +- docs/benchmarks/hires/run_hires.py | 4 +- docs/benchmarks/lotkavolterra/plot.md | 2 +- .../lotkavolterra/run_lotkavolterra.py | 4 +- docs/benchmarks/pleiades/plot.md | 2 +- docs/benchmarks/pleiades/run_pleiades.py | 4 +- docs/benchmarks/taylor_fitzhughnagumo/plot.md | 2 +- .../run_taylor_fitzhughnagumo.py | 14 +++--- docs/benchmarks/taylor_node/plot.md | 2 +- .../benchmarks/taylor_node/run_taylor_node.py | 14 +++--- docs/benchmarks/taylor_pleiades/plot.md | 2 +- .../taylor_pleiades/run_taylor_pleiades.py | 14 +++--- docs/benchmarks/vanderpol/plot.md | 2 +- docs/benchmarks/vanderpol/run_vanderpol.py | 4 +- docs/dev_docs/changelog.md | 10 ++++- docs/dev_docs/public_api.md | 1 + .../neural_ode.ipynb | 2 +- .../neural_ode.md | 4 +- .../physics_enhanced_regression_1.md | 2 +- .../physics_enhanced_regression_2.ipynb | 6 +-- .../physics_enhanced_regression_2.md | 8 ++-- .../conditioning-on-zero-residual.ipynb | 4 +- .../conditioning-on-zero-residual.md | 6 +-- .../dynamic_output_scales.ipynb | 2 +- .../dynamic_output_scales.md | 4 +- .../posterior_uncertainties.ipynb | 4 +- .../posterior_uncertainties.md | 6 +-- .../second_order_problems.ipynb | 6 +-- .../second_order_problems.md | 8 ++-- docs/examples_solver_config/smoothing.ipynb | 4 +- docs/examples_solver_config/smoothing.md | 6 +-- docs/getting_started/easy_example.ipynb | 4 +- docs/getting_started/easy_example.md | 4 +- .../transitioning_from_other_packages.md | 2 +- docs/getting_started/troubleshooting.md | 2 +- docs/index.md | 16 +++---- makefile | 6 +-- mkdocs.yml | 20 ++------- probdiffeq/{solvers => }/taylor/__init__.py | 0 probdiffeq/{solvers => }/taylor/affine.py | 4 +- probdiffeq/{solvers => }/taylor/autodiff.py | 42 +++++++++--------- probdiffeq/{solvers => }/taylor/estim.py | 0 pyproject.toml | 9 ++-- .../test_fixed_grid_vs_save_every_step.py | 4 +- .../test_save_at_vs_save_every_step.py | 4 +- tests/test_ivpsolve/test_save_every_step.py | 6 +-- tests/test_ivpsolve/test_solution_object.py | 4 +- ...test_terminal_values_vs_save_every_step.py | 4 +- ...test_mle_calibration_vs_calibrationfree.py | 10 ++--- .../test_warnings_for_wrong_strategies.py | 6 +-- .../test_log_marginal_likelihood.py | 6 +-- ...log_marginal_likelihood_terminal_values.py | 4 +- .../test_solution/test_offgrid_marginals.py | 4 +- .../test_solvers/test_solution/test_sample.py | 4 +- .../test_rmse_of_correction.py | 4 +- .../test_cubature/test_equivalence.py | 1 - .../test_filter_vs_smoother_rmse.py | 4 +- ...test_smoother_vs_fixedpoint_equivalence.py | 4 +- .../test_taylor/__init__.py | 0 .../data/generate_reference_solutions.py | 12 +++-- .../data/three_body_first_solution.npy | Bin .../data/van_der_pol_second_solution.npy | Bin .../test_taylor/test_affine_recursion.py | 4 +- .../test_taylor/test_exact_first_order.py | 14 +++--- .../test_taylor/test_exact_higher_order.py | 14 +++--- .../test_taylor/test_inexact_first_order.py | 4 +- tests/test_util/test_ibm_discrete.py | 1 - 75 files changed, 198 insertions(+), 217 deletions(-) delete mode 100644 docs/api_docs/solvers/taylor/affine.md delete mode 100644 docs/api_docs/solvers/taylor/autodiff.md delete mode 100644 docs/api_docs/solvers/taylor/estim.md create mode 100644 docs/api_docs/taylor/affine.md create mode 100644 docs/api_docs/taylor/autodiff.md create mode 100644 docs/api_docs/taylor/estim.md rename probdiffeq/{solvers => }/taylor/__init__.py (100%) rename probdiffeq/{solvers => }/taylor/affine.py (78%) rename probdiffeq/{solvers => }/taylor/autodiff.py (85%) rename probdiffeq/{solvers => }/taylor/estim.py (100%) rename tests/{test_solvers => }/test_taylor/__init__.py (100%) rename tests/{test_solvers => }/test_taylor/data/generate_reference_solutions.py (83%) rename tests/{test_solvers => }/test_taylor/data/three_body_first_solution.npy (100%) rename tests/{test_solvers => }/test_taylor/data/van_der_pol_second_solution.npy (100%) rename tests/{test_solvers => }/test_taylor/test_affine_recursion.py (87%) rename tests/{test_solvers => }/test_taylor/test_exact_first_order.py (84%) rename tests/{test_solvers => }/test_taylor/test_exact_higher_order.py (75%) rename tests/{test_solvers => }/test_taylor/test_inexact_first_order.py (92%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4cf753fc..0e133956 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,22 +8,23 @@ repos: - id: end-of-file-fixer - id: check-merge-conflict - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black-jupyter language_version: python3 - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.284 + rev: v0.0.292 hooks: - id: ruff + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/mwouts/jupytext - rev: v1.15.0 + rev: v1.15.2 hooks: - id: jupytext - files: (docs/).+ + files: ^(docs/(benchmarks|examples_solver_config|examples_parameter_estimation|getting_started)/).+.ipynb args: [--sync] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + rev: v1.5.1 hooks: - id: mypy args: [--ignore-missing-imports] diff --git a/README.md b/README.md index ffbab458..a2fbab45 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,17 @@ # probdiffeq - -[![PyPi Version](https://img.shields.io/pypi/v/probdiffeq.svg?style=flat-square&color=darkgray)](https://pypi.org/project/probdiffeq/) -[![gh-actions](https://img.shields.io/github/actions/workflow/status/pnkraemer/probdiffeq/ci.yaml?branch=main&style=flat-square)](https://github.com/pnkraemer/probdiffeq/actions?query=workflow%3Aci) -License Badge -[![GitHub stars](https://img.shields.io/github/stars/pnkraemer/probdiffeq.svg?style=flat-square&logo=github&label=Stars&logoColor=white)](https://github.com/pnkraemer/probdiffeq) -![Python](https://img.shields.io/badge/python-3.9+-black.svg?style=flat-square) - +[![Actions status](https://github.com/pnkraemer/probdiffeq/workflows/ci/badge.svg)](https://github.com/pnkraemer/probdiffeq/actions) +[![image](https://img.shields.io/pypi/v/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq) +[![image](https://img.shields.io/pypi/l/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq) +[![image](https://img.shields.io/pypi/pyversions/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq) ## Probabilistic solvers for differential equations in JAX ProbDiffEq implements adaptive probabilistic numerical solvers for initial value problems. It inherits automatic differentiation, vectorisation, and GPU capability from JAX. -Features include: + +**Features include:** * Stable implementation * Calibration, step-size adaptation, and checkpointing @@ -26,8 +24,6 @@ Features include: and many more. - - * **AN EASY EXAMPLE:** [LINK](https://pnkraemer.github.io/probdiffeq/getting_started/easy_example/) * **EXAMPLES:** [LINK](https://pnkraemer.github.io/probdiffeq/examples_solver_config/posterior_uncertainties/) * **CHOOSING A SOLVER:** [LINK](https://pnkraemer.github.io/probdiffeq/getting_started/choosing_a_solver/) diff --git a/docs/api_docs/solvers/taylor/affine.md b/docs/api_docs/solvers/taylor/affine.md deleted file mode 100644 index ec4090dc..00000000 --- a/docs/api_docs/solvers/taylor/affine.md +++ /dev/null @@ -1 +0,0 @@ -::: probdiffeq.solvers.taylor.affine diff --git a/docs/api_docs/solvers/taylor/autodiff.md b/docs/api_docs/solvers/taylor/autodiff.md deleted file mode 100644 index 77ba0796..00000000 --- a/docs/api_docs/solvers/taylor/autodiff.md +++ /dev/null @@ -1 +0,0 @@ -::: probdiffeq.solvers.taylor.autodiff diff --git a/docs/api_docs/solvers/taylor/estim.md b/docs/api_docs/solvers/taylor/estim.md deleted file mode 100644 index a5f743dc..00000000 --- a/docs/api_docs/solvers/taylor/estim.md +++ /dev/null @@ -1 +0,0 @@ -::: probdiffeq.solvers.taylor.estim diff --git a/docs/api_docs/taylor/affine.md b/docs/api_docs/taylor/affine.md new file mode 100644 index 00000000..79068433 --- /dev/null +++ b/docs/api_docs/taylor/affine.md @@ -0,0 +1 @@ +::: probdiffeq.taylor.affine diff --git a/docs/api_docs/taylor/autodiff.md b/docs/api_docs/taylor/autodiff.md new file mode 100644 index 00000000..20e65aff --- /dev/null +++ b/docs/api_docs/taylor/autodiff.md @@ -0,0 +1 @@ +::: probdiffeq.taylor.autodiff diff --git a/docs/api_docs/taylor/estim.md b/docs/api_docs/taylor/estim.md new file mode 100644 index 00000000..2eafa6da --- /dev/null +++ b/docs/api_docs/taylor/estim.md @@ -0,0 +1 @@ +::: probdiffeq.taylor.estim diff --git a/docs/benchmarks/hires/plot.md b/docs/benchmarks/hires/plot.md index 475918c0..f7ebea70 100644 --- a/docs/benchmarks/hires/plot.md +++ b/docs/benchmarks/hires/plot.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/benchmarks/hires/run_hires.py b/docs/benchmarks/hires/run_hires.py index 6ca83294..743e10bb 100644 --- a/docs/benchmarks/hires/run_hires.py +++ b/docs/benchmarks/hires/run_hires.py @@ -22,7 +22,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -105,7 +105,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num_derivatives) + tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives) init = solver.initial_condition(tcoeffs, output_scale=1.0) # Solve diff --git a/docs/benchmarks/lotkavolterra/plot.md b/docs/benchmarks/lotkavolterra/plot.md index 6eaffe4f..7256901b 100644 --- a/docs/benchmarks/lotkavolterra/plot.md +++ b/docs/benchmarks/lotkavolterra/plot.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py index b1cb3eba..301220bc 100644 --- a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py +++ b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py @@ -23,7 +23,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -93,7 +93,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num_derivatives) + tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives) output_scale = 1.0 * jnp.ones((2,)) if implementation == "blockdiag" else 1.0 init = solver.initial_condition(tcoeffs, output_scale=output_scale) diff --git a/docs/benchmarks/pleiades/plot.md b/docs/benchmarks/pleiades/plot.md index 90852516..89cfb248 100644 --- a/docs/benchmarks/pleiades/plot.md +++ b/docs/benchmarks/pleiades/plot.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/benchmarks/pleiades/run_pleiades.py b/docs/benchmarks/pleiades/run_pleiades.py index 3e805df0..86bc8b35 100644 --- a/docs/benchmarks/pleiades/run_pleiades.py +++ b/docs/benchmarks/pleiades/run_pleiades.py @@ -23,7 +23,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -117,7 +117,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode(vf_auto, (u0, du0), num=num_derivatives - 1) + tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1) init = solver.initial_condition(tcoeffs, output_scale=1.0) # Solve diff --git a/docs/benchmarks/taylor_fitzhughnagumo/plot.md b/docs/benchmarks/taylor_fitzhughnagumo/plot.md index 52beb7a9..31f64553 100644 --- a/docs/benchmarks/taylor_fitzhughnagumo/plot.md +++ b/docs/benchmarks/taylor_fitzhughnagumo/plot.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py b/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py index ca4364cb..0de07bb8 100644 --- a/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py +++ b/docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py @@ -15,7 +15,7 @@ from jax import config from probdiffeq.impl import impl -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -57,13 +57,13 @@ def timer(fun, /): return timer -def taylor_mode() -> Callable: +def taylor_mode_scan() -> Callable: """Taylor-mode estimation.""" vf_auto, (u0,) = _fitzhugh_nagumo() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num) + tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -93,13 +93,13 @@ def estimate(num): return estimate -def forward_mode() -> Callable: +def forward_mode_recursive() -> Callable: """Forward-mode estimation.""" vf_auto, (u0,) = _fitzhugh_nagumo() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.forward_mode(vf_auto, (u0,), num=num) + tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -153,8 +153,8 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: if __name__ == "__main__": set_jax_config() algorithms = { - r"Forward-mode": forward_mode(), - r"Taylor-mode (scan)": taylor_mode(), + r"Forward-mode": forward_mode_recursive(), + r"Taylor-mode (scan)": taylor_mode_scan(), r"Taylor-mode (unroll)": taylor_mode_unroll(), r"Taylor-mode (doubling)": taylor_mode_doubling(), } diff --git a/docs/benchmarks/taylor_node/plot.md b/docs/benchmarks/taylor_node/plot.md index ccfa85c9..a801375a 100644 --- a/docs/benchmarks/taylor_node/plot.md +++ b/docs/benchmarks/taylor_node/plot.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/benchmarks/taylor_node/run_taylor_node.py b/docs/benchmarks/taylor_node/run_taylor_node.py index 848030ca..f999f282 100644 --- a/docs/benchmarks/taylor_node/run_taylor_node.py +++ b/docs/benchmarks/taylor_node/run_taylor_node.py @@ -16,7 +16,7 @@ from jax import config from probdiffeq.impl import impl -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -58,13 +58,13 @@ def timer(fun, /): return timer -def taylor_mode() -> Callable: +def taylor_mode_scan() -> Callable: """Taylor-mode estimation.""" vf_auto, (u0,) = _node() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num) + tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -94,13 +94,13 @@ def estimate(num): return estimate -def forward_mode() -> Callable: +def forward_mode_recursive() -> Callable: """Forward-mode estimation.""" vf_auto, (u0,) = _node() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.forward_mode(vf_auto, (u0,), num=num) + tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -167,8 +167,8 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: set_jax_config() backend.select("jax") algorithms = { - r"Forward-mode": forward_mode(), - r"Taylor-mode (scan)": taylor_mode(), + r"Forward-mode": forward_mode_recursive(), + r"Taylor-mode (scan)": taylor_mode_scan(), r"Taylor-mode (unroll)": taylor_mode_unroll(), r"Taylor-mode (doubling)": taylor_mode_doubling(), } diff --git a/docs/benchmarks/taylor_pleiades/plot.md b/docs/benchmarks/taylor_pleiades/plot.md index a1394594..c50f38fa 100644 --- a/docs/benchmarks/taylor_pleiades/plot.md +++ b/docs/benchmarks/taylor_pleiades/plot.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py b/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py index eb277b83..b6260136 100644 --- a/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py +++ b/docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py @@ -15,7 +15,7 @@ from jax import config from probdiffeq.impl import impl -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -57,13 +57,13 @@ def timer(fun, /): return timer -def taylor_mode() -> Callable: +def taylor_mode_scan() -> Callable: """Taylor-mode estimation.""" vf_auto, (u0, du0) = _pleiades() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.taylor_mode(vf_auto, (u0, du0), num=num) + tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -81,13 +81,13 @@ def estimate(num): return estimate -def forward_mode() -> Callable: +def forward_mode_recursive() -> Callable: """Forward-mode estimation.""" vf_auto, (u0, du0) = _pleiades() @functools.partial(jax.jit, static_argnames=["num"]) def estimate(num): - tcoeffs = autodiff.forward_mode(vf_auto, (u0, du0), num=num) + tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0, du0), num=num) return jax.block_until_ready(tcoeffs) return estimate @@ -161,8 +161,8 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: if __name__ == "__main__": set_jax_config() algorithms = { - r"Forward-mode": forward_mode(), - r"Taylor-mode (scan)": taylor_mode(), + r"Forward-mode": forward_mode_recursive(), + r"Taylor-mode (scan)": taylor_mode_scan(), r"Taylor-mode (unroll)": taylor_mode_unroll(), } diff --git a/docs/benchmarks/vanderpol/plot.md b/docs/benchmarks/vanderpol/plot.md index 9c615bac..da6c960e 100644 --- a/docs/benchmarks/vanderpol/plot.md +++ b/docs/benchmarks/vanderpol/plot.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/benchmarks/vanderpol/run_vanderpol.py b/docs/benchmarks/vanderpol/run_vanderpol.py index fdfc68ce..f9a415fa 100644 --- a/docs/benchmarks/vanderpol/run_vanderpol.py +++ b/docs/benchmarks/vanderpol/run_vanderpol.py @@ -22,7 +22,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.util.doc_util import info @@ -96,7 +96,7 @@ def param_to_solution(tol): # Initial state vf_auto = functools.partial(vf_probdiffeq, t=t0) - tcoeffs = autodiff.taylor_mode(vf_auto, (u0, du0), num=num_derivatives - 1) + tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1) init = solver.initial_condition(tcoeffs, output_scale=1.0) # Solve diff --git a/docs/dev_docs/changelog.md b/docs/dev_docs/changelog.md index 1a94a82e..0b3c7905 100644 --- a/docs/dev_docs/changelog.md +++ b/docs/dev_docs/changelog.md @@ -1,11 +1,17 @@ # Changelog -## v0.2.3 +## v0.3.0 **New features:** * A new function `taylor_mode_unroll` implements Taylor-series estimation without a `scan`. +**Breaking changes:** + +* What was formerly `taylor_mode()`, is now `taylor_mode_scan()` and stands in contrast to the new `taylor_mode_unroll()`. +* What was formerly `forward_mode()`, is now `forward_mode_recursive()`. +* The entire `taylor` subpackage moved to top-level. Instead of `from probdiffeq.solvers.taylor import ...`, use `from probdiffeq.taylor import ...`. + ## v0.2.2 @@ -16,7 +22,7 @@ This release was due to issues in the publishing workflow. **Breaking changes:** * The input-argument to `taylor_mode_doubling` is `num_doublings` instead of `num`. - This argument behaves differently to e.g., `taylor_mode(..., num)`. + This argument behaves differently to e.g., `taylor_mode_scan(..., num)`. ## v0.2.0 diff --git a/docs/dev_docs/public_api.md b/docs/dev_docs/public_api.md index e3ba60e5..ea03c611 100644 --- a/docs/dev_docs/public_api.md +++ b/docs/dev_docs/public_api.md @@ -8,6 +8,7 @@ At the moment, this affects the following: * `controls.py` * `adaptive.py` * `timestep.py` +* `taylor/*` * `solvers/*` * `impl.impl.select()` diff --git a/docs/examples_parameter_estimation/neural_ode.ipynb b/docs/examples_parameter_estimation/neural_ode.ipynb index dbcb943a..45ce0278 100644 --- a/docs/examples_parameter_estimation/neural_ode.ipynb +++ b/docs/examples_parameter_estimation/neural_ode.ipynb @@ -36,7 +36,7 @@ "from probdiffeq.impl import impl\n", "from probdiffeq.util.doc_util import notebook\n", "from probdiffeq.solvers import uncalibrated, solution\n", - "from probdiffeq.solvers.taylor import autodiff\n", + "from probdiffeq.taylor import autodiff\n", "from probdiffeq.solvers.strategies.components import corrections, priors\n", "from probdiffeq.solvers.strategies import smoothers" ] diff --git a/docs/examples_parameter_estimation/neural_ode.md b/docs/examples_parameter_estimation/neural_ode.md index 5c5f0da6..aa0108a4 100644 --- a/docs/examples_parameter_estimation/neural_ode.md +++ b/docs/examples_parameter_estimation/neural_ode.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -29,7 +29,7 @@ from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.util.doc_util import notebook from probdiffeq.solvers import uncalibrated, solution -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.solvers.strategies.components import corrections, priors from probdiffeq.solvers.strategies import smoothers ``` diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_1.md b/docs/examples_parameter_estimation/physics_enhanced_regression_1.md index de88f6cd..44613c5c 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_1.md +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_1.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.ipynb b/docs/examples_parameter_estimation/physics_enhanced_regression_2.ipynb index d825a80e..6f49b4d6 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.ipynb +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.ipynb @@ -103,7 +103,7 @@ "from probdiffeq.solvers import uncalibrated, solution\n", "from probdiffeq.solvers.strategies.components import corrections, priors\n", "from probdiffeq.solvers.strategies import filters\n", - "from probdiffeq.solvers.taylor import autodiff" + "from probdiffeq.taylor import autodiff" ] }, { @@ -204,7 +204,7 @@ " strategy = filters.filter_adaptive(ibm, ts0)\n", " solver = uncalibrated.solver(strategy)\n", "\n", - " tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (theta,), num=2)\n", + " tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2)\n", " output_scale = 10.0\n", " init = solver.initial_condition(tcoeffs, output_scale)\n", "\n", @@ -221,7 +221,7 @@ " solver = uncalibrated.solver(strategy)\n", " adaptive_solver = adaptive.adaptive(solver)\n", "\n", - " tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (theta,), num=2)\n", + " tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2)\n", " output_scale = 10.0\n", " init = solver.initial_condition(tcoeffs, output_scale)\n", " return ivpsolve.solve_and_save_at(\n", diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.md b/docs/examples_parameter_estimation/physics_enhanced_regression_2.md index bf1302bd..a4cd57fb 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.md +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -104,7 +104,7 @@ from probdiffeq.util.doc_util import notebook from probdiffeq.solvers import uncalibrated, solution from probdiffeq.solvers.strategies.components import corrections, priors from probdiffeq.solvers.strategies import filters -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff ``` ```python @@ -160,7 +160,7 @@ def solve_fixed(theta, *, ts): strategy = filters.filter_adaptive(ibm, ts0) solver = uncalibrated.solver(strategy) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (theta,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2) output_scale = 10.0 init = solver.initial_condition(tcoeffs, output_scale) @@ -177,7 +177,7 @@ def solve_adaptive(theta, *, save_at): solver = uncalibrated.solver(strategy) adaptive_solver = adaptive.adaptive(solver) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (theta,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2) output_scale = 10.0 init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_and_save_at( diff --git a/docs/examples_solver_config/conditioning-on-zero-residual.ipynb b/docs/examples_solver_config/conditioning-on-zero-residual.ipynb index b705296c..ca715ddb 100644 --- a/docs/examples_solver_config/conditioning-on-zero-residual.ipynb +++ b/docs/examples_solver_config/conditioning-on-zero-residual.ipynb @@ -35,7 +35,7 @@ "from probdiffeq.impl import impl\n", "from probdiffeq.util.doc_util import notebook\n", "from probdiffeq.solvers import calibrated, uncalibrated, solution, markov\n", - "from probdiffeq.solvers.taylor import autodiff\n", + "from probdiffeq.taylor import autodiff\n", "from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint\n", "from probdiffeq.solvers.strategies.components import corrections, priors" ] @@ -117,7 +117,7 @@ "markov_seq_prior = markov.MarkovSeq(init_raw, transitions)\n", "\n", "\n", - "tcoeffs = autodiff.taylor_mode(\n", + "tcoeffs = autodiff.taylor_mode_scan(\n", " lambda y: vector_field(y, t=t0), (u0,), num=NUM_DERIVATIVES\n", ")\n", "init_tcoeffs = impl.ssm_util.normal_from_tcoeffs(\n", diff --git a/docs/examples_solver_config/conditioning-on-zero-residual.md b/docs/examples_solver_config/conditioning-on-zero-residual.md index fef5558b..c8a52cd9 100644 --- a/docs/examples_solver_config/conditioning-on-zero-residual.md +++ b/docs/examples_solver_config/conditioning-on-zero-residual.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -29,7 +29,7 @@ from probdiffeq import controls, ivpsolve, timestep, adaptive from probdiffeq.impl import impl from probdiffeq.util.doc_util import notebook from probdiffeq.solvers import calibrated, uncalibrated, solution, markov -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint from probdiffeq.solvers.strategies.components import corrections, priors ``` @@ -70,7 +70,7 @@ init_raw, transitions = priors.ibm_discretised( markov_seq_prior = markov.MarkovSeq(init_raw, transitions) -tcoeffs = autodiff.taylor_mode( +tcoeffs = autodiff.taylor_mode_scan( lambda y: vector_field(y, t=t0), (u0,), num=NUM_DERIVATIVES ) init_tcoeffs = impl.ssm_util.normal_from_tcoeffs( diff --git a/docs/examples_solver_config/dynamic_output_scales.ipynb b/docs/examples_solver_config/dynamic_output_scales.ipynb index 716a076d..312a9844 100644 --- a/docs/examples_solver_config/dynamic_output_scales.ipynb +++ b/docs/examples_solver_config/dynamic_output_scales.ipynb @@ -37,7 +37,7 @@ "\n", "from probdiffeq import ivpsolve\n", "from probdiffeq.impl import impl\n", - "from probdiffeq.solvers.taylor import affine\n", + "from probdiffeq.taylor import affine\n", "from probdiffeq.util.doc_util import notebook\n", "from probdiffeq.solvers import calibrated\n", "from probdiffeq.solvers.strategies import filters\n", diff --git a/docs/examples_solver_config/dynamic_output_scales.md b/docs/examples_solver_config/dynamic_output_scales.md index 23bc9887..c7634d53 100644 --- a/docs/examples_solver_config/dynamic_output_scales.md +++ b/docs/examples_solver_config/dynamic_output_scales.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -31,7 +31,7 @@ from jax.config import config from probdiffeq import ivpsolve from probdiffeq.impl import impl -from probdiffeq.solvers.taylor import affine +from probdiffeq.taylor import affine from probdiffeq.util.doc_util import notebook from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters diff --git a/docs/examples_solver_config/posterior_uncertainties.ipynb b/docs/examples_solver_config/posterior_uncertainties.ipynb index aeecf614..e8fc9074 100644 --- a/docs/examples_solver_config/posterior_uncertainties.ipynb +++ b/docs/examples_solver_config/posterior_uncertainties.ipynb @@ -33,7 +33,7 @@ "from probdiffeq.impl import impl\n", "from probdiffeq.util.doc_util import notebook\n", "from probdiffeq.solvers import calibrated, solution, markov\n", - "from probdiffeq.solvers.taylor import autodiff\n", + "from probdiffeq.taylor import autodiff\n", "from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint\n", "from probdiffeq.solvers.strategies.components import corrections, priors" ] @@ -160,7 +160,7 @@ "source": [ "dt0 = timestep.initial(lambda y: vf(y, t=t0), (u0,))\n", "\n", - "tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4)\n", + "tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4)\n", "init = solver.initial_condition(tcoeffs, output_scale=1.0)\n", "sol = ivpsolve.solve_and_save_at(\n", " vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver\n", diff --git a/docs/examples_solver_config/posterior_uncertainties.md b/docs/examples_solver_config/posterior_uncertainties.md index 564a8fb7..40c6beab 100644 --- a/docs/examples_solver_config/posterior_uncertainties.md +++ b/docs/examples_solver_config/posterior_uncertainties.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -26,7 +26,7 @@ from probdiffeq import ivpsolve, adaptive, timestep from probdiffeq.impl import impl from probdiffeq.util.doc_util import notebook from probdiffeq.solvers import calibrated, solution, markov -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint from probdiffeq.solvers.strategies.components import corrections, priors ``` @@ -72,7 +72,7 @@ ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) ```python dt0 = timestep.initial(lambda y: vf(y, t=t0), (u0,)) -tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4) +tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) init = solver.initial_condition(tcoeffs, output_scale=1.0) sol = ivpsolve.solve_and_save_at( vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver diff --git a/docs/examples_solver_config/second_order_problems.ipynb b/docs/examples_solver_config/second_order_problems.ipynb index 819cc338..31a0ba55 100644 --- a/docs/examples_solver_config/second_order_problems.ipynb +++ b/docs/examples_solver_config/second_order_problems.ipynb @@ -33,7 +33,7 @@ "from probdiffeq import adaptive, ivpsolve\n", "from probdiffeq.util.doc_util import notebook\n", "from probdiffeq.solvers import calibrated\n", - "from probdiffeq.solvers.taylor import autodiff\n", + "from probdiffeq.taylor import autodiff\n", "from probdiffeq.solvers.strategies.components import corrections, priors\n", "from probdiffeq.solvers.strategies import filters" ] @@ -99,7 +99,7 @@ "adaptive_solver_1st = adaptive.adaptive(solver_1st, atol=1e-5, rtol=1e-5)\n", "\n", "\n", - "tcoeffs = autodiff.taylor_mode(lambda y: vf_1(y, t=t0), (u0,), num=4)\n", + "tcoeffs = autodiff.taylor_mode_scan(lambda y: vf_1(y, t=t0), (u0,), num=4)\n", "init = solver_1st.initial_condition(tcoeffs, output_scale=1.0)" ] }, @@ -209,7 +209,7 @@ "adaptive_solver_2nd = adaptive.adaptive(solver_2nd, atol=1e-5, rtol=1e-5)\n", "\n", "\n", - "tcoeffs = autodiff.taylor_mode(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3)\n", + "tcoeffs = autodiff.taylor_mode_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3)\n", "init = solver_2nd.initial_condition(tcoeffs, output_scale=1.0)" ] }, diff --git a/docs/examples_solver_config/second_order_problems.md b/docs/examples_solver_config/second_order_problems.md index e5b26a6b..c707a77a 100644 --- a/docs/examples_solver_config/second_order_problems.md +++ b/docs/examples_solver_config/second_order_problems.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -26,7 +26,7 @@ from probdiffeq.impl import impl from probdiffeq import adaptive, ivpsolve from probdiffeq.util.doc_util import notebook from probdiffeq.solvers import calibrated -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.solvers.strategies.components import corrections, priors from probdiffeq.solvers.strategies import filters ``` @@ -58,7 +58,7 @@ solver_1st = calibrated.mle(filters.filter_adaptive(ibm, ts0)) adaptive_solver_1st = adaptive.adaptive(solver_1st, atol=1e-5, rtol=1e-5) -tcoeffs = autodiff.taylor_mode(lambda y: vf_1(y, t=t0), (u0,), num=4) +tcoeffs = autodiff.taylor_mode_scan(lambda y: vf_1(y, t=t0), (u0,), num=4) init = solver_1st.initial_condition(tcoeffs, output_scale=1.0) ``` @@ -100,7 +100,7 @@ solver_2nd = calibrated.mle(filters.filter_adaptive(ibm, ts0)) adaptive_solver_2nd = adaptive.adaptive(solver_2nd, atol=1e-5, rtol=1e-5) -tcoeffs = autodiff.taylor_mode(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3) +tcoeffs = autodiff.taylor_mode_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3) init = solver_2nd.initial_condition(tcoeffs, output_scale=1.0) ``` diff --git a/docs/examples_solver_config/smoothing.ipynb b/docs/examples_solver_config/smoothing.ipynb index f7d19330..c32e6f69 100644 --- a/docs/examples_solver_config/smoothing.ipynb +++ b/docs/examples_solver_config/smoothing.ipynb @@ -36,7 +36,7 @@ "from probdiffeq.impl import impl\n", "from probdiffeq.util.doc_util import notebook\n", "from probdiffeq.solvers import calibrated, solution\n", - "from probdiffeq.solvers.taylor import autodiff\n", + "from probdiffeq.taylor import autodiff\n", "from probdiffeq.solvers.strategies.components import corrections, priors\n", "from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint" ] @@ -104,7 +104,7 @@ " return f(*ys, *f_args)\n", "\n", "\n", - "tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4)" + "tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4)" ] }, { diff --git a/docs/examples_solver_config/smoothing.md b/docs/examples_solver_config/smoothing.md index 1b9cc8d2..890df16c 100644 --- a/docs/examples_solver_config/smoothing.md +++ b/docs/examples_solver_config/smoothing.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.15.0 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -30,7 +30,7 @@ from probdiffeq import ivpsolve, adaptive from probdiffeq.impl import impl from probdiffeq.util.doc_util import notebook from probdiffeq.solvers import calibrated, solution -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from probdiffeq.solvers.strategies.components import corrections, priors from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint ``` @@ -57,7 +57,7 @@ def vf(*ys, t): return f(*ys, *f_args) -tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4) +tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) ``` ## Terminal-value simulation diff --git a/docs/getting_started/easy_example.ipynb b/docs/getting_started/easy_example.ipynb index e1d085de..bbf0d314 100644 --- a/docs/getting_started/easy_example.ipynb +++ b/docs/getting_started/easy_example.ipynb @@ -33,7 +33,7 @@ "from probdiffeq.solvers import uncalibrated\n", "from probdiffeq.solvers.strategies import smoothers\n", "from probdiffeq.solvers.strategies.components import corrections, priors\n", - "from probdiffeq.solvers.taylor import autodiff\n", + "from probdiffeq.taylor import autodiff\n", "\n", "config.update(\"jax_platform_name\", \"cpu\")" ] @@ -186,7 +186,7 @@ }, "outputs": [], "source": [ - "tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4)\n", + "tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4)\n", "output_scale = 1.0 # or any other value with the same shape\n", "init = solver.initial_condition(tcoeffs, output_scale)" ] diff --git a/docs/getting_started/easy_example.md b/docs/getting_started/easy_example.md index 9e0ff017..047be20e 100644 --- a/docs/getting_started/easy_example.md +++ b/docs/getting_started/easy_example.md @@ -12,7 +12,7 @@ from probdiffeq.impl import impl from probdiffeq.solvers import uncalibrated from probdiffeq.solvers.strategies import smoothers from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff config.update("jax_platform_name", "cpu") ``` @@ -87,7 +87,7 @@ and to wrapping this approximation into a state-space-model variable. Use the following functions: ```python -tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4) +tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) output_scale = 1.0 # or any other value with the same shape init = solver.initial_condition(tcoeffs, output_scale) ``` diff --git a/docs/getting_started/transitioning_from_other_packages.md b/docs/getting_started/transitioning_from_other_packages.md index 6fdd0681..4fa4066f 100644 --- a/docs/getting_started/transitioning_from_other_packages.md +++ b/docs/getting_started/transitioning_from_other_packages.md @@ -39,7 +39,7 @@ ProbDiffEq can reproduce most of the implementations in Tornadox: | `solver.solve()` | `solve_and_save_every_step()` | Try `solve_and_save_at()` instead. | | `solver.simulate_final_state()` | `simulate_terminal_values()` | ProbDiffEq compiles the whole loop; it will be much faster. | | `solver.solution_generator()` | Work in progress. | | -| `init.TaylorMode()` | `taylor.autodiff.taylor_mode` | Consider `taylor.autodiff.forward_mode()` for low numbers of derivatives and `taylor.autodiff.taylor_mode_doubling()` for (absurdly) high numbers of derivatives | +| `init.TaylorMode()` | `taylor.autodiff.taylor_mode` | Consider `taylor.autodiff.forward_mode_recursive()` for low numbers of derivatives and `taylor.autodiff.taylor_mode_doubling()` for (absurdly) high numbers of derivatives | | `init.RungeKutta()` | `taylor.estim.make_runge_kutta_starter()` | | diff --git a/docs/getting_started/troubleshooting.md b/docs/getting_started/troubleshooting.md index 5163274f..be6e86e8 100644 --- a/docs/getting_started/troubleshooting.md +++ b/docs/getting_started/troubleshooting.md @@ -5,7 +5,7 @@ If a solution routine takes surprisingly long to compile but then executes quickly, it may be due to the choice of Taylor-coefficient computation. Some functions in `probdiffeq.taylor` unroll a (small) loop. -To avoid this, use `probdiffeq.taylor.autodiff.taylor_mode()` +To avoid this, use `probdiffeq.taylor.autodiff.taylor_mode_scan()` (which is implemented with a scan). If the problem persists, reduce the number of derivatives (if that is appropriate for your integration problem) diff --git a/docs/index.md b/docs/index.md index ffbab458..a2fbab45 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,19 +1,17 @@ # probdiffeq - -[![PyPi Version](https://img.shields.io/pypi/v/probdiffeq.svg?style=flat-square&color=darkgray)](https://pypi.org/project/probdiffeq/) -[![gh-actions](https://img.shields.io/github/actions/workflow/status/pnkraemer/probdiffeq/ci.yaml?branch=main&style=flat-square)](https://github.com/pnkraemer/probdiffeq/actions?query=workflow%3Aci) -License Badge -[![GitHub stars](https://img.shields.io/github/stars/pnkraemer/probdiffeq.svg?style=flat-square&logo=github&label=Stars&logoColor=white)](https://github.com/pnkraemer/probdiffeq) -![Python](https://img.shields.io/badge/python-3.9+-black.svg?style=flat-square) - +[![Actions status](https://github.com/pnkraemer/probdiffeq/workflows/ci/badge.svg)](https://github.com/pnkraemer/probdiffeq/actions) +[![image](https://img.shields.io/pypi/v/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq) +[![image](https://img.shields.io/pypi/l/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq) +[![image](https://img.shields.io/pypi/pyversions/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq) ## Probabilistic solvers for differential equations in JAX ProbDiffEq implements adaptive probabilistic numerical solvers for initial value problems. It inherits automatic differentiation, vectorisation, and GPU capability from JAX. -Features include: + +**Features include:** * Stable implementation * Calibration, step-size adaptation, and checkpointing @@ -26,8 +24,6 @@ Features include: and many more. - - * **AN EASY EXAMPLE:** [LINK](https://pnkraemer.github.io/probdiffeq/getting_started/easy_example/) * **EXAMPLES:** [LINK](https://pnkraemer.github.io/probdiffeq/examples_solver_config/posterior_uncertainties/) * **CHOOSING A SOLVER:** [LINK](https://pnkraemer.github.io/probdiffeq/getting_started/choosing_a_solver/) diff --git a/makefile b/makefile index 7a552e38..fcdd7dcc 100644 --- a/makefile +++ b/makefile @@ -4,7 +4,7 @@ format: black --quiet . isort --quiet . - jupytext --quiet --sync docs/quickstart/*.ipynb + jupytext --quiet --sync docs/getting_started/*.ipynb jupytext --quiet --sync docs/examples_solver_config/*.ipynb jupytext --quiet --sync docs/examples_parameter_estimation/*.ipynb jupytext --quiet --sync docs/benchmarks/hires/*.ipynb @@ -25,8 +25,8 @@ test: IMPL=scalar pytest -n auto -v # parallelise, verbose output example: - jupytext --quiet --sync docs/quickstart/*.ipynb - jupytext --quiet --execute docs/quickstart/*.ipynb + jupytext --quiet --sync docs/getting_started/*.ipynb + jupytext --quiet --execute docs/getting_started/*.ipynb jupytext --quiet --sync docs/examples_solver_config/* jupytext --quiet --execute docs/examples_solver_config/* jupytext --quiet --sync docs/examples_solver_config/* diff --git a/mkdocs.yml b/mkdocs.yml index d4dea1cd..d911dd64 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -63,20 +63,8 @@ plugins: handlers: python: options: - show_root_heading: true - show_root_toc_entry: true - show_root_full_path: true - show_root_members_full_path: true - show_object_full_path: false - show_category_heading: true docstring_style: numpy - show_if_no_docstring: true - members_order: alphabetical - annotations_path: brief - show_signature: true show_signature_annotations: true - separate_signature: false - docstring_section_style: list - exclude: glob: - getting_started/easy_example.md @@ -147,14 +135,14 @@ nav: - smoothers: api_docs/solvers/strategies/smoothers.md - fixedpoint: api_docs/solvers/strategies/fixedpoint.md - discrete: api_docs/solvers/strategies/discrete.md - - taylor: - - affine: api_docs/solvers/taylor/affine.md - - autodiff: api_docs/solvers/taylor/autodiff.md - - estim: api_docs/solvers/taylor/estim.md - calibrated: api_docs/solvers/calibrated.md - uncalibrated: api_docs/solvers/uncalibrated.md - solution: api_docs/solvers/solution.md - markov: api_docs/solvers/markov.md + - taylor: + - affine: api_docs/taylor/affine.md + - autodiff: api_docs/taylor/autodiff.md + - estim: api_docs/taylor/estim.md - DEVELOPER DOCUMENTATION: - Changelog: dev_docs/changelog.md - dev_docs/public_api.md diff --git a/probdiffeq/solvers/taylor/__init__.py b/probdiffeq/taylor/__init__.py similarity index 100% rename from probdiffeq/solvers/taylor/__init__.py rename to probdiffeq/taylor/__init__.py diff --git a/probdiffeq/solvers/taylor/affine.py b/probdiffeq/taylor/affine.py similarity index 78% rename from probdiffeq/solvers/taylor/affine.py rename to probdiffeq/taylor/affine.py index 3563cb9c..27d57553 100644 --- a/probdiffeq/solvers/taylor/affine.py +++ b/probdiffeq/taylor/affine.py @@ -1,6 +1,5 @@ r"""Taylor-expand the solution of an initial value problem (IVP).""" -import functools from typing import Callable import jax @@ -8,8 +7,7 @@ import jax.experimental.ode -@functools.partial(jax.jit, static_argnums=[0], static_argnames=["num"]) -def affine_recursion(vf: Callable, initial_values: tuple, /, num: int): +def affine_recursion(vf: Callable, initial_values: tuple[jax.Array, ...], /, num: int): """Evaluate the Taylor series of an affine differential equation. !!! warning "Compilation time" diff --git a/probdiffeq/solvers/taylor/autodiff.py b/probdiffeq/taylor/autodiff.py similarity index 85% rename from probdiffeq/solvers/taylor/autodiff.py rename to probdiffeq/taylor/autodiff.py index ecb2177b..588097f0 100644 --- a/probdiffeq/solvers/taylor/autodiff.py +++ b/probdiffeq/taylor/autodiff.py @@ -9,7 +9,7 @@ import jax.numpy as jnp -def taylor_mode(vf: Callable, initial_values: tuple, /, num: int): +def taylor_mode_scan(vf: Callable, inits: tuple[jax.Array, ...], /, num: int): """Taylor-expand the solution of an IVP with Taylor-mode differentiation. Other than `taylor_mode_unroll()`, this function implements the loop via a scan, @@ -21,11 +21,11 @@ def taylor_mode(vf: Callable, initial_values: tuple, /, num: int): Consult the benchmarks if performance is critical. """ # Number of positional arguments in f - num_arguments = len(initial_values) + num_arguments = len(inits) # Initial Taylor series (u_0, u_1, ..., u_k) - primals = vf(*initial_values) - taylor_coeffs = [*initial_values, primals] + primals = vf(*inits) + taylor_coeffs = [*inits, primals] def body(tcoeffs, _): # Pad the Taylor coefficients in zeros, call jet, and return the solution. @@ -33,12 +33,12 @@ def body(tcoeffs, _): # is independent of the $i+j$th input coefficient # (see also the explanation in taylor_mode_doubling) series = _subsets(tcoeffs[1:], num_arguments) # for high-order ODEs - p, s_new = jax.experimental.jet.jet(vf, primals=initial_values, series=series) + p, s_new = jax.experimental.jet.jet(vf, primals=inits, series=series) # The final values in s_new are nonsensical # (well, they are not; but we don't care about them) # so we remove them - tcoeffs = [*initial_values, p, *s_new[:-1]] + tcoeffs = [*inits, p, *s_new[:-1]] return tcoeffs, None # Pad the initial Taylor series with zeros @@ -56,12 +56,12 @@ def body(tcoeffs, _): return taylor_coeffs -def taylor_mode_unroll(vf: Callable, initial_values: tuple, /, num: int): +def taylor_mode_unroll(vf: Callable, inits: tuple[jax.Array, ...], /, num: int): """Taylor-expand the solution of an IVP with Taylor-mode differentiation. - Other than `taylor_mode()`, this function does not depend on zero-padding + Other than `taylor_mode_scan()`, this function does not depend on zero-padding the coefficients at the price of unrolling a loop of length `num-1`. - It is expected to compile more slowly than `taylor_mode()`, + It is expected to compile more slowly than `taylor_mode_scan()`, but execute more quickly. The differences should be small. @@ -72,16 +72,16 @@ def taylor_mode_unroll(vf: Callable, initial_values: tuple, /, num: int): """ # Number of positional arguments in f - num_arguments = len(initial_values) + num_arguments = len(inits) # Initial Taylor series (u_0, u_1, ..., u_k) - primals = vf(*initial_values) - taylor_coeffs = [*initial_values, primals] + primals = vf(*inits) + taylor_coeffs = [*inits, primals] for _ in range(num - 1): series = _subsets(taylor_coeffs[1:], num_arguments) # for high-order ODEs - p, s_new = jax.experimental.jet.jet(vf, primals=initial_values, series=series) - taylor_coeffs = [*initial_values, p, *s_new] + p, s_new = jax.experimental.jet.jet(vf, primals=inits, series=series) + taylor_coeffs = [*inits, p, *s_new] return taylor_coeffs @@ -111,18 +111,18 @@ def mask(i): return [x[mask(k) : mask(k + 1 - n)] for k in range(n)] -def forward_mode(vf: Callable, initial_values: tuple, /, num: int): - """Taylor-expand the solution of an IVP with forward-mode differentiation. +def forward_mode_recursive(vf: Callable, inits: tuple[jax.Array, ...], /, num: int): + """Taylor-expand the solution of an IVP with recursive forward-mode differentiation. !!! warning "Compilation time" JIT-compiling this function unrolls a loop. """ g_n, g_0 = vf, vf - taylor_coeffs = [*initial_values, vf(*initial_values)] + taylor_coeffs = [*inits, vf(*inits)] for _ in range(num - 1): g_n = _fwd_recursion_iterate(fun_n=g_n, fun_0=g_0) - taylor_coeffs = [*taylor_coeffs, g_n(*initial_values)] + taylor_coeffs = [*taylor_coeffs, g_n(*inits)] return taylor_coeffs @@ -140,7 +140,9 @@ def df(*args): return jax.tree_util.Partial(df) -def taylor_mode_doubling(vf: Callable, initial_values: tuple, /, num_doublings: int): +def taylor_mode_doubling( + vf: Callable, inits: tuple[jax.Array, ...], /, num_doublings: int +): """Combine Taylor-mode differentiation and Newton's doubling. !!! warning "Warning: highly EXPERIMENTAL feature!" @@ -153,7 +155,7 @@ def taylor_mode_doubling(vf: Callable, initial_values: tuple, /, num_doublings: JIT-compiling this function unrolls a loop. """ - (u0,) = initial_values + (u0,) = inits zeros = jnp.zeros_like(u0) def jet_embedded(*c, degree): diff --git a/probdiffeq/solvers/taylor/estim.py b/probdiffeq/taylor/estim.py similarity index 100% rename from probdiffeq/solvers/taylor/estim.py rename to probdiffeq/taylor/estim.py diff --git a/pyproject.toml b/pyproject.toml index 4dcd3bcd..3195b975 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,9 +14,12 @@ description = "Probabilistic numerical solvers for differential equations" readme = "README.md" requires-python=">=3.9" classifiers = [ - "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ] dynamic = ["version"] @@ -30,8 +33,8 @@ test =[ "pytest-cases", "pytest-cov", "diffeqzoo", - "diffrax<0.4", - "equinox<0.10.7", + "diffrax", + "equinox", ] lint =[ "pre-commit", diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index f06d7dea..04999e6e 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -9,7 +9,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -23,7 +23,7 @@ def test_fixed_grid_result_matches_adaptive_grid_result(): control = controls.integral_clipped() # Any clipped controller will do. adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2, control=control) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) args = (vf, init) diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index f974c835..232c2898 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -8,7 +8,7 @@ from probdiffeq.solvers import solution, uncalibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -24,7 +24,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(): solver = uncalibrated.solver(strategy) adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) problem_args = (vf, init) diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index fe105936..c9294417 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -9,7 +9,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -26,10 +26,8 @@ def fixture_python_loop_solution(): dt0 = timestep.initial_adaptive( vf, u0, t0=t0, atol=1e-2, rtol=1e-2, error_contraction_rate=5 ) - dt0_ = timestep.initial(lambda y: vf(y, t=t0), u0) - print(dt0, dt0_) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = jnp.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) diff --git a/tests/test_ivpsolve/test_solution_object.py b/tests/test_ivpsolve/test_solution_object.py index 7dad3b95..1c3579ca 100644 --- a/tests/test_ivpsolve/test_solution_object.py +++ b/tests/test_ivpsolve/test_solution_object.py @@ -8,7 +8,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -24,7 +24,7 @@ def fixture_approximate_solution(): adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=1) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=1) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_and_save_every_step( diff --git a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py index 7099799a..010146ca 100644 --- a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py @@ -9,7 +9,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -24,7 +24,7 @@ def fixture_problem_args_kwargs(): solver = calibrated.mle(strategy) adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale=output_scale) diff --git a/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py b/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py index afa5d90f..ab29de75 100644 --- a/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py +++ b/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py @@ -12,14 +12,14 @@ from probdiffeq.solvers import calibrated, solution, uncalibrated from probdiffeq.solvers.strategies import filters, fixedpoint from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @testing.case() def case_solve_fixed_grid(): vf, u0, (t0, t1) = setup.ode() - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = jnp.ones_like(impl.prototypes.output_scale()) kwargs = {"grid": jnp.linspace(t0, t1, endpoint=True, num=5)} @@ -34,7 +34,7 @@ def solver_to_solution(solver): def case_solve_and_save_at(): vf, u0, (t0, t1) = setup.ode() dt0 = timestep.initial(lambda y: vf(y, t=t0), u0) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = jnp.ones_like(impl.prototypes.output_scale()) kwargs = {"save_at": jnp.linspace(t0, t1, endpoint=True, num=5), "dt0": dt0} @@ -52,7 +52,7 @@ def solver_to_solution(solver): def case_solve_and_save_every_step(): vf, u0, (t0, t1) = setup.ode() dt0 = timestep.initial(lambda y: vf(y, t=t0), u0) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = jnp.ones_like(impl.prototypes.output_scale()) kwargs = {"t0": t0, "t1": t1, "dt0": dt0} @@ -70,7 +70,7 @@ def solver_to_solution(solver): def case_simulate_terminal_values(): vf, u0, (t0, t1) = setup.ode() dt0 = timestep.initial(lambda y: vf(y, t=t0), u0) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=4) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) output_scale = jnp.ones_like(impl.prototypes.output_scale()) kwargs = {"t0": t0, "t1": t1, "dt0": dt0} diff --git a/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py b/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py index accacd0b..990da098 100644 --- a/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py +++ b/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py @@ -7,7 +7,7 @@ from probdiffeq.solvers import uncalibrated from probdiffeq.solvers.strategies import fixedpoint, smoothers from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -21,7 +21,7 @@ def test_warning_for_fixedpoint_in_save_every_step_mode(): adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) with testing.warns(): @@ -40,7 +40,7 @@ def test_warning_for_smoother_in_save_at_mode(): adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) with testing.warns(): diff --git a/tests/test_solvers/test_solution/test_log_marginal_likelihood.py b/tests/test_solvers/test_solution/test_log_marginal_likelihood.py index 68f35be8..180dba9b 100644 --- a/tests/test_solvers/test_solution/test_log_marginal_likelihood.py +++ b/tests/test_solvers/test_solution/test_log_marginal_likelihood.py @@ -8,7 +8,7 @@ from probdiffeq.solvers import solution, uncalibrated from probdiffeq.solvers.strategies import filters, fixedpoint from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -23,7 +23,7 @@ def fixture_sol(): adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) save_at = jnp.linspace(t0, t1, endpoint=True, num=4) @@ -103,7 +103,7 @@ def test_raises_error_for_filter(): solver = uncalibrated.solver(strategy) grid = jnp.linspace(t0, t1, num=3) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) sol = ivpsolve.solve_fixed_grid(vf, init, grid=grid, solver=solver) diff --git a/tests/test_solvers/test_solution/test_log_marginal_likelihood_terminal_values.py b/tests/test_solvers/test_solution/test_log_marginal_likelihood_terminal_values.py index 9c1114b3..1a51e9fe 100644 --- a/tests/test_solvers/test_solution/test_log_marginal_likelihood_terminal_values.py +++ b/tests/test_solvers/test_solution/test_log_marginal_likelihood_terminal_values.py @@ -7,7 +7,7 @@ from probdiffeq.solvers import solution, uncalibrated from probdiffeq.solvers.strategies import filters, fixedpoint, smoothers from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -37,7 +37,7 @@ def fixture_sol(strategy_func): solver = uncalibrated.solver(strategy) adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) output_scale = jnp.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.simulate_terminal_values( diff --git a/tests/test_solvers/test_solution/test_offgrid_marginals.py b/tests/test_solvers/test_solution/test_offgrid_marginals.py index 86a81fb3..7b35b312 100644 --- a/tests/test_solvers/test_solution/test_offgrid_marginals.py +++ b/tests/test_solvers/test_solution/test_offgrid_marginals.py @@ -6,7 +6,7 @@ from probdiffeq.solvers import solution, uncalibrated from probdiffeq.solvers.strategies import filters, smoothers from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -43,7 +43,7 @@ def test_smoother_marginals_close_to_both_boundaries(): solver = uncalibrated.solver(strategy) output_scale = jnp.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=4) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) init = solver.initial_condition(tcoeffs, output_scale) grid = jnp.linspace(t0, t1, endpoint=True, num=5) sol = ivpsolve.solve_fixed_grid(vf, init, grid=grid, solver=solver) diff --git a/tests/test_solvers/test_solution/test_sample.py b/tests/test_solvers/test_solution/test_sample.py index 5e561185..b62a53b4 100644 --- a/tests/test_solvers/test_solution/test_sample.py +++ b/tests/test_solvers/test_solution/test_sample.py @@ -8,7 +8,7 @@ from probdiffeq.solvers import markov, uncalibrated from probdiffeq.solvers.strategies import smoothers from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -23,7 +23,7 @@ def fixture_approximation(): adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.solve_and_save_every_step( vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1 diff --git a/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py b/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py index 63865138..9e9f15ee 100644 --- a/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py +++ b/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py @@ -9,7 +9,7 @@ from probdiffeq.solvers import calibrated from probdiffeq.solvers.strategies import filters from probdiffeq.solvers.strategies.components import corrections, cubature, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -73,7 +73,7 @@ def fixture_solution(correction_impl): adaptive_kwargs = {"adaptive_solver": adaptive_solver, "dt0": 0.1} - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), u0, num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = jnp.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) return ivpsolve.simulate_terminal_values(vf, init, t0=t0, t1=t1, **adaptive_kwargs) diff --git a/tests/test_solvers/test_strategies/test_cubature/test_equivalence.py b/tests/test_solvers/test_strategies/test_cubature/test_equivalence.py index 1933364a..b22458be 100644 --- a/tests/test_solvers/test_strategies/test_cubature/test_equivalence.py +++ b/tests/test_solvers/test_strategies/test_cubature/test_equivalence.py @@ -23,7 +23,6 @@ def test_third_order_spherical_vs_unscented_transform(n=4): tos_points, tos_weights = tos.points, tos.weights_sqrtm ut_points, ut_weights = ut.points, ut.weights_sqrtm for x, y in [(ut_weights, tos_weights), (ut_points, tos_points)]: - print(x, y) assert jnp.allclose(x[:n], y[:n]) assert jnp.allclose(x[n], 0.0) assert jnp.allclose(x[n + 1 :], y[n:]) diff --git a/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py b/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py index 972e421f..6a3d088b 100644 --- a/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py +++ b/tests/test_solvers/test_strategies/test_filter_vs_smoother_rmse.py @@ -9,7 +9,7 @@ from probdiffeq.solvers import uncalibrated from probdiffeq.solvers.strategies import filters, smoothers from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -19,7 +19,7 @@ def fixture_solver_setup(): output_scale = jnp.ones_like(impl.prototypes.output_scale()) grid = jnp.linspace(t0, t1, endpoint=True, num=12) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) return {"vf": vf, "tcoeffs": tcoeffs, "grid": grid, "output_scale": output_scale} diff --git a/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py b/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py index 35aac10f..b9063199 100644 --- a/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py +++ b/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py @@ -11,7 +11,7 @@ from probdiffeq.solvers import solution, uncalibrated from probdiffeq.solvers.strategies import fixedpoint, smoothers from probdiffeq.solvers.strategies.components import corrections, priors -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff from tests.setup import setup @@ -21,7 +21,7 @@ def fixture_solver_setup(): output_scale = jnp.ones_like(impl.prototypes.output_scale()) - tcoeffs = autodiff.taylor_mode(lambda y: vf(y, t=t0), (u0,), num=2) + tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) return { "vf": vf, "tcoeffs": tcoeffs, diff --git a/tests/test_solvers/test_taylor/__init__.py b/tests/test_taylor/__init__.py similarity index 100% rename from tests/test_solvers/test_taylor/__init__.py rename to tests/test_taylor/__init__.py diff --git a/tests/test_solvers/test_taylor/data/generate_reference_solutions.py b/tests/test_taylor/data/generate_reference_solutions.py similarity index 83% rename from tests/test_solvers/test_taylor/data/generate_reference_solutions.py rename to tests/test_taylor/data/generate_reference_solutions.py index fea048c8..b44a1f5b 100644 --- a/tests/test_solvers/test_taylor/data/generate_reference_solutions.py +++ b/tests/test_taylor/data/generate_reference_solutions.py @@ -4,7 +4,7 @@ from diffeqzoo import backend from jax.config import config -from probdiffeq.solvers import taylor +from probdiffeq.taylor import autodiff def set_environment(): @@ -29,7 +29,7 @@ def three_body_first(num_derivatives_max=6): def vf(u, *, t, p): # noqa: ARG001 return f(u, *p) - return taylor.taylor_mode_fn( + return autodiff.taylor_mode_unroll( vector_field=vf, initial_values=(u0,), num=num_derivatives_max, @@ -44,7 +44,7 @@ def van_der_pol_second(num_derivatives_max=6): def vf(u, du, *, t, p): # noqa: ARG001 return f(u, du, *p) - return taylor.taylor_mode_fn( + return autodiff.taylor_mode_unroll( vector_field=vf, initial_values=(u0, du0), num=num_derivatives_max, @@ -58,13 +58,11 @@ def vf(u, du, *, t, p): # noqa: ARG001 set_environment() solution1 = three_body_first() - jnp.save( - "./tests/test_solvers/test_taylor/data/three_body_first_solution.npy", solution1 - ) + jnp.save("./tests/test_taylor/data/three_body_first_solution.npy", solution1) solution2 = van_der_pol_second() jnp.save( - "./tests/test_solvers/test_taylor/data/van_der_pol_second_solution.npy", + "./tests/test_taylor/data/van_der_pol_second_solution.npy", solution2, ) diff --git a/tests/test_solvers/test_taylor/data/three_body_first_solution.npy b/tests/test_taylor/data/three_body_first_solution.npy similarity index 100% rename from tests/test_solvers/test_taylor/data/three_body_first_solution.npy rename to tests/test_taylor/data/three_body_first_solution.npy diff --git a/tests/test_solvers/test_taylor/data/van_der_pol_second_solution.npy b/tests/test_taylor/data/van_der_pol_second_solution.npy similarity index 100% rename from tests/test_solvers/test_taylor/data/van_der_pol_second_solution.npy rename to tests/test_taylor/data/van_der_pol_second_solution.npy diff --git a/tests/test_solvers/test_taylor/test_affine_recursion.py b/tests/test_taylor/test_affine_recursion.py similarity index 87% rename from tests/test_solvers/test_taylor/test_affine_recursion.py rename to tests/test_taylor/test_affine_recursion.py index f2433f6e..7fbac47b 100644 --- a/tests/test_solvers/test_taylor/test_affine_recursion.py +++ b/tests/test_taylor/test_affine_recursion.py @@ -3,7 +3,7 @@ import jax.numpy as jnp from probdiffeq.backend import testing -from probdiffeq.solvers.taylor import affine, autodiff +from probdiffeq.taylor import affine, autodiff @testing.parametrize("num", [1, 2, 4]) @@ -30,5 +30,5 @@ def vf(x, /): init = (jnp.arange(9.0, 11.0),) - solution = autodiff.taylor_mode(vf, init, num=n) + solution = autodiff.taylor_mode_scan(vf, init, num=n) return vf, init, solution diff --git a/tests/test_solvers/test_taylor/test_exact_first_order.py b/tests/test_taylor/test_exact_first_order.py similarity index 84% rename from tests/test_solvers/test_taylor/test_exact_first_order.py rename to tests/test_taylor/test_exact_first_order.py index ab6f067a..07c6a4e5 100644 --- a/tests/test_solvers/test_taylor/test_exact_first_order.py +++ b/tests/test_taylor/test_exact_first_order.py @@ -4,17 +4,17 @@ import jax.numpy as jnp from probdiffeq.backend import testing -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff @testing.case() -def case_forward_mode(): - return autodiff.forward_mode +def case_forward_mode_recursive(): + return autodiff.forward_mode_recursive @testing.case() -def case_taylor_mode(): - return autodiff.taylor_mode +def case_taylor_mode_scan(): + return autodiff.taylor_mode_scan @testing.case() @@ -29,9 +29,7 @@ def fixture_pb_with_solution(): def vf(u, /): return f(u, *f_args) - solution = jnp.load( - "./tests/test_solvers/test_taylor/data/three_body_first_solution.npy" - ) + solution = jnp.load("./tests/test_taylor/data/three_body_first_solution.npy") return (vf, (u0,)), solution diff --git a/tests/test_solvers/test_taylor/test_exact_higher_order.py b/tests/test_taylor/test_exact_higher_order.py similarity index 75% rename from tests/test_solvers/test_taylor/test_exact_higher_order.py rename to tests/test_taylor/test_exact_higher_order.py index 141bb12f..26049560 100644 --- a/tests/test_solvers/test_taylor/test_exact_higher_order.py +++ b/tests/test_taylor/test_exact_higher_order.py @@ -4,17 +4,17 @@ import jax.numpy as jnp from probdiffeq.backend import testing -from probdiffeq.solvers.taylor import autodiff +from probdiffeq.taylor import autodiff @testing.case() -def case_forward_mode(): - return autodiff.forward_mode +def case_forward_mode_recursive(): + return autodiff.forward_mode_recursive @testing.case() -def case_taylor_mode(): - return autodiff.taylor_mode +def case_taylor_mode_scan(): + return autodiff.taylor_mode_scan @testing.fixture(name="pb_with_solution") @@ -24,9 +24,7 @@ def fixture_pb_with_solution(): def vf(u, du, /): return f(u, du, *f_args) - solution = jnp.load( - "./tests/test_solvers/test_taylor/data/van_der_pol_second_solution.npy" - ) + solution = jnp.load("./tests/test_taylor/data/van_der_pol_second_solution.npy") return (vf, (u0, du0)), solution diff --git a/tests/test_solvers/test_taylor/test_inexact_first_order.py b/tests/test_taylor/test_inexact_first_order.py similarity index 92% rename from tests/test_solvers/test_taylor/test_inexact_first_order.py rename to tests/test_taylor/test_inexact_first_order.py index 596d5f41..ce9e06fb 100644 --- a/tests/test_solvers/test_taylor/test_inexact_first_order.py +++ b/tests/test_taylor/test_inexact_first_order.py @@ -4,7 +4,7 @@ from probdiffeq.backend import testing from probdiffeq.impl import impl -from probdiffeq.solvers.taylor import autodiff, estim +from probdiffeq.taylor import autodiff, estim @testing.case() @@ -21,7 +21,7 @@ def fixture_pb_with_solution(): def vf(u, /): return f(u, *f_args) - solution = autodiff.taylor_mode(vf, (u0,), num=3) + solution = autodiff.taylor_mode_scan(vf, (u0,), num=3) return (vf, (u0,), t0), solution diff --git a/tests/test_util/test_ibm_discrete.py b/tests/test_util/test_ibm_discrete.py index 5a0b18fd..fe1ac9c5 100644 --- a/tests/test_util/test_ibm_discrete.py +++ b/tests/test_util/test_ibm_discrete.py @@ -39,7 +39,6 @@ def step(rv, model): _, rvs = jax.lax.scan(step, init=init, xs=transitions, reverse=False) means = impl.stats.mean(rvs) - print(jax.tree_util.tree_map(jnp.shape, rvs)) # todo: does this conflict with error estimation? stds = impl.stats.standard_deviation(rvs)