Skip to content

Commit

Permalink
Warning about compilation times (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer authored Mar 15, 2023
1 parent f6b6bb7 commit 2516e8c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
32 changes: 32 additions & 0 deletions docs/quickstart/troubleshooting.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Troubleshooting

## Long compilation times

If a solution routine takes surprisingly long to compile but then executes quickly,
it may be due to the choice of Taylor-coefficient computation.
Automatic-differentiation-based routines such as Taylor-mode or forward-mode Taylor-series
estimation JIT-compile the ODE vector field $\nu$, respectively $\nu(\nu+1)/2$ times
for $\nu$ derivatives in the state-space model (commonly referred to as `num_derivatives`).
On top of this, the vector field is compiled a final time for the "actual" simulation.

As a solution, either reduce the number of derivatives
(if that is appropriate for your integration problem)
or switch to a different Taylor-coefficient routine.
For example, use
```python
simulate_terminal_values(..., taylor_fn=taylor.make_runge_kutta_starter_fn())
solve_and_save_at(..., taylor_fn=taylor.make_runge_kutta_starter_fn())
# etc.
```
instead of
```python
simulate_terminal_values(..., taylor_fn=taylor.taylor_mode_fn)
solve_and_save_at(..., taylor_fn=taylor.taylor_mode_fn)
# etc.
```
For $\nu < 5$, switching to Runge-Kutta starters should preserve performance of the solvers.
High-order methods, e.g. $\nu = 9$ are only possible with `taylor_fn=taylor.taylor_mode_fn`.


## Other problems
Your problem is not discussed here? Please open an issue!
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ nav:
- Quickstart:
- "quickstart/quickstart.ipynb"
- "quickstart/transitioning_from_other_packages.md"
- "quickstart/troubleshooting.md"
- Examples:
- "examples/posterior_uncertainties.ipynb"
- "examples/second_order_problems.ipynb"
Expand Down
29 changes: 26 additions & 3 deletions probdiffeq/taylor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
def taylor_mode_fn(
*, vector_field: Callable, initial_values: Tuple, num: int, t, parameters
):
"""Taylor-expand the solution of an IVP with Taylor-mode differentiation."""
"""Taylor-expand the solution of an IVP with Taylor-mode differentiation.
!!! warning "Compilation time"
JIT-compiling this function unrolls a loop of length `num`
and JIT-compiles the `vector_field` exactly `num` times.
"""
# Number of positional arguments in f
num_arguments = len(initial_values)

Expand Down Expand Up @@ -59,7 +65,15 @@ def mask(i):
def forward_mode_fn(
*, vector_field: Callable, initial_values: Tuple, num: int, t, parameters
):
"""Taylor-expand the solution of an IVP with forward-mode differentiation."""
"""Taylor-expand the solution of an IVP with forward-mode differentiation.
!!! warning "Compilation time"
JIT-compiling this function unrolls a loop of length `num`
and JIT-compiles the `vector_field` exactly `num(num+1)/2` times.
"""
vf = jax.tree_util.Partial(vector_field, t=t, p=parameters)

g_n, g_0 = vf, vf
Expand Down Expand Up @@ -88,7 +102,12 @@ def df(*args):
def affine_recursion(
*, vector_field: Callable, initial_values: Tuple, num: int, t, parameters
):
"""Evaluate the Taylor series of an affine differential equation."""
"""Evaluate the Taylor series of an affine differential equation.
!!! warning "Compilation time"
JIT-compiling this function unrolls a loop of length `num`.
"""
if num == 0:
return initial_values

Expand Down Expand Up @@ -203,6 +222,10 @@ def taylor_mode_doubling_fn(
It might be deleted tomorrow
and without any deprecation policy.
!!! warning "Compilation time"
JIT-compiling this function unrolls a loop of length `num`
and JIT-compiles the `vector_field` O(log(`num`)) times.
"""
vf = jax.tree_util.Partial(vector_field, t=t, p=parameters)
(u0,) = initial_values
Expand Down

0 comments on commit 2516e8c

Please sign in to comment.