-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Taylor series benchmark: Fitzhugh-Nagumo (#656)
* Normalised the colors * Added a benchmark for fitzhugn-nagumo * Mention benchmark creation in hte dev docs * Mention new benchmark in makefile * Rerun benchmark * FHN benchmark config * FHN results improved
- Loading branch information
Showing
8 changed files
with
2,982 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
--- | ||
jupyter: | ||
jupytext: | ||
formats: ipynb,md | ||
text_representation: | ||
extension: .md | ||
format_name: markdown | ||
format_version: '1.3' | ||
jupytext_version: 1.15.0 | ||
kernelspec: | ||
display_name: Python 3 (ipykernel) | ||
language: python | ||
name: python3 | ||
--- | ||
|
||
# Taylor-series: FitzHugh-Nagumo | ||
|
||
The FHN problem is a common non-stiff differential equation. | ||
|
||
```python | ||
"""Benchmark all Taylor-series estimators on the Fitzhugh-Nagumo problem.""" | ||
|
||
import jax.numpy as jnp | ||
import matplotlib.pyplot as plt | ||
from jax.config import config | ||
|
||
from probdiffeq.util.doc_util import notebook | ||
|
||
config.update("jax_platform_name", "cpu") | ||
``` | ||
|
||
```python | ||
def load_results(): | ||
"""Load the results from a file.""" | ||
return jnp.load("./results.npy", allow_pickle=True)[()] | ||
|
||
|
||
def choose_style(label): | ||
"""Choose a plotting style for a given algorithm.""" | ||
if "doubling" in label.lower(): | ||
return {"color": "C2", "linestyle": "dotted"} | ||
if "taylor" in label.lower(): | ||
return {"color": "C0", "linestyle": "solid"} | ||
if "forward" in label.lower(): | ||
return {"color": "C1", "linestyle": "dashed"} | ||
msg = f"Label {label} unknown." | ||
raise ValueError(msg) | ||
|
||
|
||
def plot_results(axis_compile, axis_perform, results): | ||
"""Plot the results.""" | ||
for label, wp in results.items(): | ||
style = choose_style(label) | ||
|
||
inputs = wp["arguments"] | ||
work_mean = wp["work_compile"] | ||
axis_compile.semilogy(inputs, work_mean, label=label, **style) | ||
|
||
work_mean, work_std = wp["work_mean"], wp["work_std"] | ||
work_min, work_max = (wp["work_min"], wp["work_max"]) | ||
range_lower, range_upper = work_mean - work_std, work_mean + work_std | ||
axis_perform.semilogy(inputs, work_mean, label=label, **style) | ||
axis_perform.fill_between(inputs, range_lower, range_upper, alpha=0.3, **style) | ||
|
||
return axis_compile, axis_perform | ||
``` | ||
|
||
```python | ||
plt.rcParams.update(notebook.plot_config()) | ||
|
||
fig, (axis_perform, axis_compile) = plt.subplots( | ||
ncols=2, dpi=150, figsize=(8, 3), sharex=True, tight_layout=True | ||
) | ||
fig.suptitle("FitzHugh-Nagumo problem, Taylor-series estimation") | ||
|
||
results = load_results() | ||
axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results) | ||
|
||
axis_compile.set_title("Compile time") | ||
axis_perform.set_title("Evaluation time") | ||
axis_perform.legend() | ||
axis_compile.set_xlabel("Number of Derivatives") | ||
axis_perform.set_xlabel("Number of Derivatives") | ||
axis_perform.set_ylabel("Wall time (sec)") | ||
axis_perform.grid() | ||
axis_compile.grid() | ||
|
||
plt.show() | ||
``` | ||
|
||
```python | ||
|
||
``` |
Binary file not shown.
175 changes: 175 additions & 0 deletions
175
docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
"""Benchmark the initialisation methods on the FitzHugh-Nagumo problem. | ||
See makefile for instructions. | ||
""" | ||
import argparse | ||
import functools | ||
import os | ||
import statistics | ||
import time | ||
import timeit | ||
from typing import Callable | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import config | ||
|
||
from probdiffeq.impl import impl | ||
from probdiffeq.solvers.taylor import autodiff | ||
from probdiffeq.util.doc_util import info | ||
|
||
|
||
def set_jax_config() -> None: | ||
"""Set JAX and other external libraries up.""" | ||
# x64 precision | ||
config.update("jax_enable_x64", True) | ||
|
||
# CPU | ||
config.update("jax_platform_name", "cpu") | ||
|
||
|
||
def set_probdiffeq_config() -> None: | ||
"""Set probdiffeq up.""" | ||
impl.select("isotropic", ode_shape=(14,)) | ||
|
||
|
||
def print_library_info() -> None: | ||
"""Print the environment info for this benchmark.""" | ||
info.print_info() | ||
print("\n------------------------------------------\n") | ||
|
||
|
||
def parse_arguments() -> argparse.Namespace: | ||
"""Parse the arguments from the command line.""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--max_time", type=float) | ||
parser.add_argument("--repeats", type=int, default=3) | ||
parser.add_argument("--save", action=argparse.BooleanOptionalAction) | ||
return parser.parse_args() | ||
|
||
|
||
def timeit_fun_from_args(arguments: argparse.Namespace, /) -> Callable: | ||
"""Construct a timeit-function from the command-line arguments.""" | ||
|
||
def timer(fun, /): | ||
return list(timeit.repeat(fun, number=1, repeat=arguments.repeats)) | ||
|
||
return timer | ||
|
||
|
||
def taylor_mode() -> Callable: | ||
"""Taylor-mode estimation.""" | ||
vf_auto, (u0,) = _fitzhugh_nagumo() | ||
|
||
@functools.partial(jax.jit, static_argnames=["num"]) | ||
def estimate(num): | ||
tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num) | ||
return jax.block_until_ready(tcoeffs) | ||
|
||
return estimate | ||
|
||
|
||
def taylor_mode_doubling() -> Callable: | ||
"""Taylor-mode estimation.""" | ||
vf_auto, (u0,) = _fitzhugh_nagumo() | ||
|
||
@functools.partial(jax.jit, static_argnames=["num"]) | ||
def estimate(num): | ||
tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num=num) | ||
return jax.block_until_ready(tcoeffs) | ||
|
||
return estimate | ||
|
||
|
||
def forward_mode() -> Callable: | ||
"""Forward-mode estimation.""" | ||
vf_auto, (u0,) = _fitzhugh_nagumo() | ||
|
||
@functools.partial(jax.jit, static_argnames=["num"]) | ||
def estimate(num): | ||
tcoeffs = autodiff.forward_mode(vf_auto, (u0,), num=num) | ||
return jax.block_until_ready(tcoeffs) | ||
|
||
return estimate | ||
|
||
|
||
def _fitzhugh_nagumo(): | ||
u0 = jnp.asarray([-1.0, 1.0]) | ||
|
||
@jax.jit | ||
def vf_probdiffeq(u, a=0.2, b=0.2, c=3.0): | ||
"""FitzHugh--Nagumo model.""" | ||
du1 = c * (u[0] - u[0] ** 3 / 3 + u[1]) | ||
du2 = -(1.0 / c) * (u[0] - a - b * u[1]) | ||
return jnp.asarray([du1, du2]) | ||
|
||
return vf_probdiffeq, (u0,) | ||
|
||
|
||
def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: | ||
"""Benchmark a function iteratively until a max-time threshold is exceeded.""" | ||
work_compile = [] | ||
work_mean = [] | ||
work_std = [] | ||
work_min = [] | ||
work_median = [] | ||
work_max = [] | ||
arguments = [] | ||
|
||
t0 = time.perf_counter() | ||
arg = 1 | ||
while (elapsed := time.perf_counter() - t0) < max_time: | ||
print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time) | ||
t0 = time.perf_counter() | ||
_ = fun(arg) | ||
t1 = time.perf_counter() | ||
time_compile = t1 - t0 | ||
|
||
time_execute = timeit_fun(lambda: fun(arg)) # noqa: B023 | ||
|
||
arguments.append(arg + 1) # plus one, because second-order problem | ||
work_compile.append(time_compile) | ||
work_mean.append(statistics.mean(time_execute)) | ||
work_std.append(statistics.stdev(time_execute)) | ||
work_min.append(min(time_execute)) | ||
work_median.append(statistics.median(time_execute)) | ||
work_max.append(max(time_execute)) | ||
arg += 1 | ||
print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time) | ||
return { | ||
"work_median": jnp.asarray(work_median), | ||
"work_mean": jnp.asarray(work_mean), | ||
"work_std": jnp.asarray(work_std), | ||
"work_min": jnp.asarray(work_min), | ||
"work_max": jnp.asarray(work_max), | ||
"work_compile": jnp.asarray(work_compile), | ||
"arguments": jnp.asarray(arguments), | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
set_jax_config() | ||
algorithms = { | ||
r"Forward-mode": forward_mode(), | ||
r"Taylor-mode": taylor_mode(), | ||
r"Taylor-mode (doubling)": taylor_mode_doubling(), | ||
} | ||
|
||
# Compute a reference solution | ||
args = parse_arguments() | ||
timeit_fun = timeit_fun_from_args(args) | ||
|
||
# Compute all work-precision diagrams | ||
results = {} | ||
for label, algo in algorithms.items(): | ||
print("\n") | ||
print(label) | ||
results[label] = adaptive_benchmark( | ||
algo, timeit_fun=timeit_fun, max_time=args.max_time | ||
) | ||
# Save results | ||
if args.save: | ||
jnp.save(os.path.dirname(__file__) + "/results.npy", results) | ||
print("\nSaving successful.\n") | ||
else: | ||
print("\nSkipped saving.\n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters