Skip to content

Commit

Permalink
Taylor series benchmark: Fitzhugh-Nagumo (#656)
Browse files Browse the repository at this point in the history
* Normalised the colors

* Added a benchmark for fitzhugn-nagumo

* Mention benchmark creation in hte dev docs

* Mention new benchmark in makefile

* Rerun benchmark

* FHN benchmark config

* FHN results improved
  • Loading branch information
pnkraemer authored Oct 4, 2023
1 parent 6af2833 commit 7ea9fe0
Show file tree
Hide file tree
Showing 8 changed files with 2,982 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ docs/benchmarks/pleiades/*.npy
docs/benchmarks/vanderpol/*.npy
docs/benchmarks/lotkavolterra/*.npy
docs/benchmarks/taylor_pleiades/*.npy
docs/benchmarks/taylor_fitzhughnagumo/*.npy

# IDE stuff
.idea/
Expand Down
2,682 changes: 2,682 additions & 0 deletions docs/benchmarks/taylor_fitzhughnagumo/plot.ipynb

Large diffs are not rendered by default.

93 changes: 93 additions & 0 deletions docs/benchmarks/taylor_fitzhughnagumo/plot.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
---
jupyter:
jupytext:
formats: ipynb,md
text_representation:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# Taylor-series: FitzHugh-Nagumo

The FHN problem is a common non-stiff differential equation.

```python
"""Benchmark all Taylor-series estimators on the Fitzhugh-Nagumo problem."""

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.config import config

from probdiffeq.util.doc_util import notebook

config.update("jax_platform_name", "cpu")
```

```python
def load_results():
"""Load the results from a file."""
return jnp.load("./results.npy", allow_pickle=True)[()]


def choose_style(label):
"""Choose a plotting style for a given algorithm."""
if "doubling" in label.lower():
return {"color": "C2", "linestyle": "dotted"}
if "taylor" in label.lower():
return {"color": "C0", "linestyle": "solid"}
if "forward" in label.lower():
return {"color": "C1", "linestyle": "dashed"}
msg = f"Label {label} unknown."
raise ValueError(msg)


def plot_results(axis_compile, axis_perform, results):
"""Plot the results."""
for label, wp in results.items():
style = choose_style(label)

inputs = wp["arguments"]
work_mean = wp["work_compile"]
axis_compile.semilogy(inputs, work_mean, label=label, **style)

work_mean, work_std = wp["work_mean"], wp["work_std"]
work_min, work_max = (wp["work_min"], wp["work_max"])
range_lower, range_upper = work_mean - work_std, work_mean + work_std
axis_perform.semilogy(inputs, work_mean, label=label, **style)
axis_perform.fill_between(inputs, range_lower, range_upper, alpha=0.3, **style)

return axis_compile, axis_perform
```

```python
plt.rcParams.update(notebook.plot_config())

fig, (axis_perform, axis_compile) = plt.subplots(
ncols=2, dpi=150, figsize=(8, 3), sharex=True, tight_layout=True
)
fig.suptitle("FitzHugh-Nagumo problem, Taylor-series estimation")

results = load_results()
axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results)

axis_compile.set_title("Compile time")
axis_perform.set_title("Evaluation time")
axis_perform.legend()
axis_compile.set_xlabel("Number of Derivatives")
axis_perform.set_xlabel("Number of Derivatives")
axis_perform.set_ylabel("Wall time (sec)")
axis_perform.grid()
axis_compile.grid()

plt.show()
```

```python

```
Binary file added docs/benchmarks/taylor_fitzhughnagumo/results.npy
Binary file not shown.
175 changes: 175 additions & 0 deletions docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Benchmark the initialisation methods on the FitzHugh-Nagumo problem.
See makefile for instructions.
"""
import argparse
import functools
import os
import statistics
import time
import timeit
from typing import Callable

import jax
import jax.numpy as jnp
from jax import config

from probdiffeq.impl import impl
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.util.doc_util import info


def set_jax_config() -> None:
"""Set JAX and other external libraries up."""
# x64 precision
config.update("jax_enable_x64", True)

# CPU
config.update("jax_platform_name", "cpu")


def set_probdiffeq_config() -> None:
"""Set probdiffeq up."""
impl.select("isotropic", ode_shape=(14,))


def print_library_info() -> None:
"""Print the environment info for this benchmark."""
info.print_info()
print("\n------------------------------------------\n")


def parse_arguments() -> argparse.Namespace:
"""Parse the arguments from the command line."""
parser = argparse.ArgumentParser()
parser.add_argument("--max_time", type=float)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument("--save", action=argparse.BooleanOptionalAction)
return parser.parse_args()


def timeit_fun_from_args(arguments: argparse.Namespace, /) -> Callable:
"""Construct a timeit-function from the command-line arguments."""

def timer(fun, /):
return list(timeit.repeat(fun, number=1, repeat=arguments.repeats))

return timer


def taylor_mode() -> Callable:
"""Taylor-mode estimation."""
vf_auto, (u0,) = _fitzhugh_nagumo()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate


def taylor_mode_doubling() -> Callable:
"""Taylor-mode estimation."""
vf_auto, (u0,) = _fitzhugh_nagumo()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate


def forward_mode() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0,) = _fitzhugh_nagumo()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate


def _fitzhugh_nagumo():
u0 = jnp.asarray([-1.0, 1.0])

@jax.jit
def vf_probdiffeq(u, a=0.2, b=0.2, c=3.0):
"""FitzHugh--Nagumo model."""
du1 = c * (u[0] - u[0] ** 3 / 3 + u[1])
du2 = -(1.0 / c) * (u[0] - a - b * u[1])
return jnp.asarray([du1, du2])

return vf_probdiffeq, (u0,)


def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
"""Benchmark a function iteratively until a max-time threshold is exceeded."""
work_compile = []
work_mean = []
work_std = []
work_min = []
work_median = []
work_max = []
arguments = []

t0 = time.perf_counter()
arg = 1
while (elapsed := time.perf_counter() - t0) < max_time:
print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time)
t0 = time.perf_counter()
_ = fun(arg)
t1 = time.perf_counter()
time_compile = t1 - t0

time_execute = timeit_fun(lambda: fun(arg)) # noqa: B023

arguments.append(arg + 1) # plus one, because second-order problem
work_compile.append(time_compile)
work_mean.append(statistics.mean(time_execute))
work_std.append(statistics.stdev(time_execute))
work_min.append(min(time_execute))
work_median.append(statistics.median(time_execute))
work_max.append(max(time_execute))
arg += 1
print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time)
return {
"work_median": jnp.asarray(work_median),
"work_mean": jnp.asarray(work_mean),
"work_std": jnp.asarray(work_std),
"work_min": jnp.asarray(work_min),
"work_max": jnp.asarray(work_max),
"work_compile": jnp.asarray(work_compile),
"arguments": jnp.asarray(arguments),
}


if __name__ == "__main__":
set_jax_config()
algorithms = {
r"Forward-mode": forward_mode(),
r"Taylor-mode": taylor_mode(),
r"Taylor-mode (doubling)": taylor_mode_doubling(),
}

# Compute a reference solution
args = parse_arguments()
timeit_fun = timeit_fun_from_args(args)

# Compute all work-precision diagrams
results = {}
for label, algo in algorithms.items():
print("\n")
print(label)
results[label] = adaptive_benchmark(
algo, timeit_fun=timeit_fun, max_time=args.max_time
)
# Save results
if args.save:
jnp.save(os.path.dirname(__file__) + "/results.npy", results)
print("\nSaving successful.\n")
else:
print("\nSkipped saving.\n")
16 changes: 15 additions & 1 deletion docs/dev_docs/creating_example_notebook.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Creating an example notebook

To create a new example notebook or benchmark and include it in the documentation, follow the steps:

## Tutorial


To create a new example notebook and include it in the documentation, follow the steps:

1. Create a jupyter notebook, preferably in `docs/examples_*/` and fill it with content.
In case you are wondering which subfolder is most appropriate:
Expand All @@ -11,3 +15,13 @@ To create a new example notebook or benchmark and include it in the documentatio
3. Include the notebook into the docs by mentioning it in the `nav` section of `mkdocs.yml`
4. Update the makefile to enjoy formatting and linting
5. Enjoy.


## Benchmark

1. Create a new folder in the `docs/benchmarks/` directory
2. Create the benchmark script. Usually, the execution is in a python script and the plotting in a jupyter notebook.
3. Link the (plotting-)notebook to a markdown file (for better version control).
4. Include the (plotting-)notebook into the docs via `mkdocs.yml`. Mention the markdown and python script in the same folder under `mkdocs.yml -> exclude`
5. Mention the new benchmark in the makefile (`clean`, `format`, `run-benchmarks`, `dry-run-benchmarks`). A dry-run is for checking that the code functions properly. The benchmark run itself should not take less than a minute, otherwise the whole benchmark suite grows out of hand.
6. Mention the juypter caches and potential data-storage in the `.gitignore`
11 changes: 8 additions & 3 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ format:
jupytext --quiet --sync docs/benchmarks/vanderpol/*.ipynb
jupytext --quiet --sync docs/benchmarks/lotkavolterra/*.ipynb
jupytext --quiet --sync docs/benchmarks/taylor_pleiades/*.ipynb
jupytext --quiet --sync docs/benchmarks/taylor_fitzhughnagumo/*.ipynb

lint:
pre-commit run --all-files
Expand All @@ -32,6 +33,9 @@ example:
jupytext --quiet --sync docs/examples_parameter_estimation/*

run-benchmarks:
time python docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py --max_time 20 --repeats 15 --save
jupytext --quiet --sync docs/benchmarks/taylor_fitzhughnagumo/*.ipynb
jupytext --quiet --execute docs/benchmarks/taylor_fitzhughnagumo/*.ipynb
time python docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py --max_time 15 --repeats 5 --save
jupytext --quiet --sync docs/benchmarks/taylor_pleiades/*.ipynb
jupytext --quiet --execute docs/benchmarks/taylor_pleiades/*.ipynb
Expand All @@ -49,9 +53,9 @@ run-benchmarks:
jupytext --quiet --execute docs/benchmarks/hires/*.ipynb

dry-run-benchmarks:
time python docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py --max_time 0.5 --repeats 2 --no-save
time python docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py --max_time 0.5 --repeats 2 --no-save
jupytext --quiet --sync docs/benchmarks/taylor_pleiades/*.ipynb
jupytext --quiet --execute docs/benchmarks/taylor_pleiades/*.ipynb
time python docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py --max_time 0.5 --repeats 2 --no-save
time python docs/benchmarks/lotkavolterra/run_lotkavolterra.py --start 3 --stop 5 --repeats 2 --no-save
time python docs/benchmarks/vanderpol/run_vanderpol.py --start 1 --stop 3 --repeats 2 --no-save
time python docs/benchmarks/pleiades/run_pleiades.py --start 3 --stop 5 --repeats 2 --no-save
Expand All @@ -74,7 +78,8 @@ clean:
rm -rf docs/benchmarks/vanderpol/.ipynb_checkpoints
rm -rf docs/benchmarks/taylor_pleiades/__pycache__
rm -rf docs/benchmarks/taylor_pleiades/.ipynb_checkpoints
rm docs/benchmarks/hires/*.npy
rm -rf docs/benchmarks/taylor_fitzhughnagumo/__pycache__
rm -rf docs/benchmarks/taylor_fitzhughnagumo/.ipynb_checkpoints

doc:
mkdocs build
12 changes: 8 additions & 4 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ theme:
- search.highlight
palette:
# Palette toggle for light mode
- scheme: probdiffeq-light
- scheme: default
primary: white
accent: amber
toggle:
icon: material/brightness-7
icon: material/eye
name: Switch to dark mode

# Palette toggle for dark mode
- scheme: probdiffeq-dark
- scheme: slate
primary: black
accent: amber
toggle:
icon: material/brightness-4
icon: material/eye-outline
name: Switch to light mode
icon:
repo: fontawesome/brands/github
Expand Down

0 comments on commit 7ea9fe0

Please sign in to comment.