From 2516e8c708b5aa7f35254ddc6c027a262b78d7cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 15 Mar 2023 18:57:17 +0100 Subject: [PATCH] Warning about compilation times (#467) --- docs/quickstart/troubleshooting.md | 32 ++++++++++++++++++++++++++++++ mkdocs.yml | 1 + probdiffeq/taylor.py | 29 ++++++++++++++++++++++++--- 3 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 docs/quickstart/troubleshooting.md diff --git a/docs/quickstart/troubleshooting.md b/docs/quickstart/troubleshooting.md new file mode 100644 index 00000000..d197b8c6 --- /dev/null +++ b/docs/quickstart/troubleshooting.md @@ -0,0 +1,32 @@ +# Troubleshooting + +## Long compilation times + +If a solution routine takes surprisingly long to compile but then executes quickly, +it may be due to the choice of Taylor-coefficient computation. +Automatic-differentiation-based routines such as Taylor-mode or forward-mode Taylor-series +estimation JIT-compile the ODE vector field $\nu$, respectively $\nu(\nu+1)/2$ times +for $\nu$ derivatives in the state-space model (commonly referred to as `num_derivatives`). +On top of this, the vector field is compiled a final time for the "actual" simulation. + +As a solution, either reduce the number of derivatives +(if that is appropriate for your integration problem) +or switch to a different Taylor-coefficient routine. +For example, use +```python +simulate_terminal_values(..., taylor_fn=taylor.make_runge_kutta_starter_fn()) +solve_and_save_at(..., taylor_fn=taylor.make_runge_kutta_starter_fn()) +# etc. +``` +instead of +```python +simulate_terminal_values(..., taylor_fn=taylor.taylor_mode_fn) +solve_and_save_at(..., taylor_fn=taylor.taylor_mode_fn) +# etc. +``` +For $\nu < 5$, switching to Runge-Kutta starters should preserve performance of the solvers. +High-order methods, e.g. $\nu = 9$ are only possible with `taylor_fn=taylor.taylor_mode_fn`. + + +## Other problems +Your problem is not discussed here? Please open an issue! \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 5cf580fd..bacca7bc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ nav: - Quickstart: - "quickstart/quickstart.ipynb" - "quickstart/transitioning_from_other_packages.md" + - "quickstart/troubleshooting.md" - Examples: - "examples/posterior_uncertainties.ipynb" - "examples/second_order_problems.ipynb" diff --git a/probdiffeq/taylor.py b/probdiffeq/taylor.py index 21591298..14cc858e 100644 --- a/probdiffeq/taylor.py +++ b/probdiffeq/taylor.py @@ -15,7 +15,13 @@ def taylor_mode_fn( *, vector_field: Callable, initial_values: Tuple, num: int, t, parameters ): - """Taylor-expand the solution of an IVP with Taylor-mode differentiation.""" + """Taylor-expand the solution of an IVP with Taylor-mode differentiation. + + !!! warning "Compilation time" + JIT-compiling this function unrolls a loop of length `num` + and JIT-compiles the `vector_field` exactly `num` times. + + """ # Number of positional arguments in f num_arguments = len(initial_values) @@ -59,7 +65,15 @@ def mask(i): def forward_mode_fn( *, vector_field: Callable, initial_values: Tuple, num: int, t, parameters ): - """Taylor-expand the solution of an IVP with forward-mode differentiation.""" + """Taylor-expand the solution of an IVP with forward-mode differentiation. + + !!! warning "Compilation time" + JIT-compiling this function unrolls a loop of length `num` + and JIT-compiles the `vector_field` exactly `num(num+1)/2` times. + + + + """ vf = jax.tree_util.Partial(vector_field, t=t, p=parameters) g_n, g_0 = vf, vf @@ -88,7 +102,12 @@ def df(*args): def affine_recursion( *, vector_field: Callable, initial_values: Tuple, num: int, t, parameters ): - """Evaluate the Taylor series of an affine differential equation.""" + """Evaluate the Taylor series of an affine differential equation. + + !!! warning "Compilation time" + JIT-compiling this function unrolls a loop of length `num`. + + """ if num == 0: return initial_values @@ -203,6 +222,10 @@ def taylor_mode_doubling_fn( It might be deleted tomorrow and without any deprecation policy. + !!! warning "Compilation time" + JIT-compiling this function unrolls a loop of length `num` + and JIT-compiles the `vector_field` O(log(`num`)) times. + """ vf = jax.tree_util.Partial(vector_field, t=t, p=parameters) (u0,) = initial_values