Skip to content

Commit

Permalink
Taylor series benchmark (#655)
Browse files Browse the repository at this point in the history
* Taylor-series estimator draft

* Pleiades benchmark updates

* Run a bigger benchmark

* Moved pleiades benchmark folder around

* Cleaned up makefile

* Include new benchmark in docs

* Rerun benchmark for docs
  • Loading branch information
pnkraemer authored Oct 3, 2023
1 parent 0145fc7 commit 6af2833
Show file tree
Hide file tree
Showing 11 changed files with 2,992 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ docs/benchmarks/hires/*.npy
docs/benchmarks/pleiades/*.npy
docs/benchmarks/vanderpol/*.npy
docs/benchmarks/lotkavolterra/*.npy
docs/benchmarks/taylor_pleiades/*.npy

# IDE stuff
.idea/
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/hires/plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "2f928018",
"metadata": {},
"source": [
"# High Irradiance Response (HIRES)\n",
"# Hires\n",
"\n",
"The HIRES problem is a common stiff differential equation."
]
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/hires/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jupyter:
name: python3
---

# High Irradiance Response (HIRES)
# Hires

The HIRES problem is a common stiff differential equation.

Expand Down
2,700 changes: 2,700 additions & 0 deletions docs/benchmarks/taylor_pleiades/plot.ipynb

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions docs/benchmarks/taylor_pleiades/plot.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
---
jupyter:
jupytext:
formats: ipynb,md
text_representation:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# Taylor-series: Pleiades

The Pleiades problem is a common non-stiff differential equation.

```python
"""Benchmark all Taylor-series estimators on the Pleiades problem."""

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.config import config

from probdiffeq.util.doc_util import notebook

config.update("jax_platform_name", "cpu")
```

```python
def load_results():
"""Load the results from a file."""
return jnp.load("./results.npy", allow_pickle=True)[()]


def choose_style(label):
"""Choose a plotting style for a given algorithm."""
if "taylor" in label.lower():
return {"color": "C0", "linestyle": "solid"}
if "forward" in label.lower():
return {"color": "C1", "linestyle": "dashed"}
msg = f"Label {label} unknown."
raise ValueError(msg)


def plot_results(axis_compile, axis_perform, results):
"""Plot the results."""
for label, wp in results.items():
style = choose_style(label)

inputs = wp["arguments"]
work_mean = wp["work_compile"]
axis_compile.semilogy(inputs, work_mean, label=label, **style)

work_mean, work_std = (wp["work_mean"], wp["work_std"])
range_lower, range_upper = work_mean - work_std, work_mean + work_std
axis_perform.semilogy(inputs, work_mean, label=label, **style)
axis_perform.fill_between(inputs, range_lower, range_upper, alpha=0.3, **style)

return axis_compile, axis_perform
```

```python
plt.rcParams.update(notebook.plot_config())

fig, (axis_perform, axis_compile) = plt.subplots(
ncols=2, dpi=150, figsize=(8, 3), sharex=True, tight_layout=True
)
fig.suptitle("Pleiades problem, Taylor-series estimation")

results = load_results()
axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results)

axis_compile.set_title("Compile time")
axis_perform.set_title("Evaluation time")
axis_perform.legend()
axis_compile.set_xlabel("Number of Derivatives")
axis_perform.set_xlabel("Number of Derivatives")
axis_perform.set_ylabel("Wall time (sec)")
axis_perform.grid()
axis_compile.grid()

plt.show()
```

```python

```
Binary file added docs/benchmarks/taylor_pleiades/results.npy
Binary file not shown.
172 changes: 172 additions & 0 deletions docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""Benchmark the initialisation methods on Pleiades.
See makefile for instructions.
"""
import argparse
import functools
import os
import statistics
import time
import timeit
from typing import Callable

import jax
import jax.numpy as jnp
from jax import config

from probdiffeq.impl import impl
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.util.doc_util import info


def set_jax_config() -> None:
"""Set JAX and other external libraries up."""
# x64 precision
config.update("jax_enable_x64", True)

# CPU
config.update("jax_platform_name", "cpu")


def set_probdiffeq_config() -> None:
"""Set probdiffeq up."""
impl.select("isotropic", ode_shape=(14,))


def print_library_info() -> None:
"""Print the environment info for this benchmark."""
info.print_info()
print("\n------------------------------------------\n")


def parse_arguments() -> argparse.Namespace:
"""Parse the arguments from the command line."""
parser = argparse.ArgumentParser()
parser.add_argument("--max_time", type=float)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument("--save", action=argparse.BooleanOptionalAction)
return parser.parse_args()


def timeit_fun_from_args(arguments: argparse.Namespace, /) -> Callable:
"""Construct a timeit-function from the command-line arguments."""

def timer(fun, /):
return list(timeit.repeat(fun, number=1, repeat=arguments.repeats))

return timer


def taylor_mode() -> Callable:
"""Taylor-mode estimation."""
vf_auto, (u0, du0) = _pleiades()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode(vf_auto, (u0, du0), num=num)
return jax.block_until_ready(tcoeffs)

return estimate


def forward_mode() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0, du0) = _pleiades()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode(vf_auto, (u0, du0), num=num)
return jax.block_until_ready(tcoeffs)

return estimate


def _pleiades():
# fmt: off
u0 = jnp.asarray(
[
3.0, 3.0, -1.0, -3.00, 2.0, -2.00, 2.0,
3.0, -3.0, 2.0, 0.00, 0.0, -4.00, 4.0,
]
)
du0 = jnp.asarray(
[
0.0, 0.0, 0.0, 0.00, 0.0, 1.75, -1.5,
0.0, 0.0, 0.0, -1.25, 1.0, 0.00, 0.0,
]
)
# fmt: on
t0 = 0.0

@jax.jit
def vf_probdiffeq(u, du, *, t=t0): # noqa: ARG001
"""Pleiades problem."""
x = u[0:7] # x
y = u[7:14] # y
xi, xj = x[:, None], x[None, :]
yi, yj = y[:, None], y[None, :]
rij = ((xi - xj) ** 2 + (yi - yj) ** 2) ** (3 / 2)
mj = jnp.arange(1, 8)[None, :]
ddx = jnp.sum(jnp.nan_to_num(mj * (xj - xi) / rij), axis=1)
ddy = jnp.sum(jnp.nan_to_num(mj * (yj - yi) / rij), axis=1)
return jnp.concatenate((ddx, ddy))

return vf_probdiffeq, (u0, du0)


def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
"""Benchmark a function iteratively until a max-time threshold is exceeded."""
work_compile = []
work_mean = []
work_std = []
arguments = []

t0 = time.perf_counter()
arg = 1
while (elapsed := time.perf_counter() - t0) < max_time:
print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time)
t0 = time.perf_counter()
_ = fun(arg)
t1 = time.perf_counter()
time_compile = t1 - t0

time_execute = timeit_fun(lambda: fun(arg)) # noqa: B023

arguments.append(arg + 1) # plus one, because second-order problem
work_compile.append(time_compile)
work_mean.append(statistics.mean(time_execute))
work_std.append(statistics.stdev(time_execute))
arg += 1
return {
"work_mean": jnp.asarray(work_mean),
"work_std": jnp.asarray(work_std),
"work_compile": jnp.asarray(work_compile),
"arguments": jnp.asarray(arguments),
}


if __name__ == "__main__":
set_jax_config()
algorithms = {
r"Forward-mode": forward_mode(),
r"Taylor-mode": taylor_mode(),
}

# Compute a reference solution
args = parse_arguments()
timeit_fun = timeit_fun_from_args(args)

# Compute all work-precision diagrams
results = {}
for label, algo in algorithms.items():
print("\n")
print(label)
results[label] = adaptive_benchmark(
algo, timeit_fun=timeit_fun, max_time=args.max_time
)
# Save results
if args.save:
jnp.save(os.path.dirname(__file__) + "/results.npy", results)
print("\nSaving successful.\n")
else:
print("\nSkipped saving.\n")
2 changes: 1 addition & 1 deletion docs/benchmarks/vanderpol/plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "2f928018",
"metadata": {},
"source": [
"# Van der Pol\n",
"# Stiff Van-der-Pol\n",
"\n",
"The van der Pol problem is a common stiff differential equation."
]
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/vanderpol/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jupyter:
name: python3
---

# Van der Pol
# Stiff Van-der-Pol

The van der Pol problem is a common stiff differential equation.

Expand Down
21 changes: 15 additions & 6 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ format:
jupytext --quiet --sync docs/benchmarks/pleiades/*.ipynb
jupytext --quiet --sync docs/benchmarks/vanderpol/*.ipynb
jupytext --quiet --sync docs/benchmarks/lotkavolterra/*.ipynb
jupytext --quiet --sync docs/benchmarks/taylor_pleiades/*.ipynb

lint:
pre-commit run --all-files
Expand All @@ -31,6 +32,9 @@ example:
jupytext --quiet --sync docs/examples_parameter_estimation/*

run-benchmarks:
time python docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py --max_time 15 --repeats 5 --save
jupytext --quiet --sync docs/benchmarks/taylor_pleiades/*.ipynb
jupytext --quiet --execute docs/benchmarks/taylor_pleiades/*.ipynb
time python docs/benchmarks/lotkavolterra/run_lotkavolterra.py --start 3 --stop 12 --repeats 20 --save
jupytext --quiet --sync docs/benchmarks/lotkavolterra/*.ipynb
jupytext --quiet --execute docs/benchmarks/lotkavolterra/*.ipynb
Expand All @@ -45,6 +49,9 @@ run-benchmarks:
jupytext --quiet --execute docs/benchmarks/hires/*.ipynb

dry-run-benchmarks:
time python docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py --max_time 0.5 --repeats 2 --no-save
jupytext --quiet --sync docs/benchmarks/taylor_pleiades/*.ipynb
jupytext --quiet --execute docs/benchmarks/taylor_pleiades/*.ipynb
time python docs/benchmarks/lotkavolterra/run_lotkavolterra.py --start 3 --stop 5 --repeats 2 --no-save
time python docs/benchmarks/vanderpol/run_vanderpol.py --start 1 --stop 3 --repeats 2 --no-save
time python docs/benchmarks/pleiades/run_pleiades.py --start 3 --stop 5 --repeats 2 --no-save
Expand All @@ -57,14 +64,16 @@ clean:
rm -rf *.egg-info
rm -rf dist site build
rm -rf *.ipynb_checkpoints
rm -rf docs/examples_benchmarks/benchmarks/lotka_volterra/__pycache__
rm -rf docs/examples_benchmarks/benchmarks/lotka_volterra/.ipynb_checkpoints
rm -rf docs/examples_benchmarks/benchmarks/pleiades/__pycache__
rm -rf docs/examples_benchmarks/benchmarks/pleiades/.ipynb_checkpoints
rm -rf docs/examples_benchmarks/benchmarks/stiff_van_der_pol/__pycache__
rm -rf docs/examples_benchmarks/benchmarks/stiff_van_der_pol/.ipynb_checkpoints
rm -rf docs/benchmarks/hires/__pycache__
rm -rf docs/benchmarks/hires/.ipynb_checkpoints
rm -rf docs/benchmarks/pleiades/__pycache__
rm -rf docs/benchmarks/pleiades/.ipynb_checkpoints
rm -rf docs/benchmarks/lotkavolterra/__pycache__
rm -rf docs/benchmarks/lotkavolterra/.ipynb_checkpoints
rm -rf docs/benchmarks/vanderpol/__pycache__
rm -rf docs/benchmarks/vanderpol/.ipynb_checkpoints
rm -rf docs/benchmarks/taylor_pleiades/__pycache__
rm -rf docs/benchmarks/taylor_pleiades/.ipynb_checkpoints
rm docs/benchmarks/hires/*.npy

doc:
Expand Down
16 changes: 10 additions & 6 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ plugins:
- benchmarks/lotkavolterra/*.md
- benchmarks/lotkavolterra/*.py
- benchmarks/lotkavolterra/.ipynb_checkpoints/*
- benchmarks/taylor_pleiades/*.md
- benchmarks/taylor_pleiades/*.py
- benchmarks/taylor_pleiades/.ipynb_checkpoints/*
- mkdocs-jupyter
extra:
social:
Expand Down Expand Up @@ -120,7 +123,7 @@ nav:
- API DOCUMENTATION:
- ivpsolve: api_docs/ivpsolve.md
- adaptive: api_docs/adaptive.md
- api_docs/controls.md
- controls: api_docs/controls.md
- timestep: api_docs/timestep.md
- impl: api_docs/impl.md
- solvers:
Expand All @@ -147,8 +150,9 @@ nav:
- dev_docs/public_api.md
- dev_docs/creating_example_notebook.md
- dev_docs/continuous_integration.md
- Benchmarks:
- Lotka-Volterra: benchmarks/lotkavolterra/plot.ipynb
- Van-der-Pol: benchmarks/vanderpol/plot.ipynb
- Pleiades: benchmarks/pleiades/plot.ipynb
- Hires: benchmarks/hires/plot.ipynb
- BENCHMARKS:
- benchmarks/lotkavolterra/plot.ipynb
- benchmarks/vanderpol/plot.ipynb
- benchmarks/pleiades/plot.ipynb
- benchmarks/hires/plot.ipynb
- benchmarks/taylor_pleiades/plot.ipynb

0 comments on commit 6af2833

Please sign in to comment.