Taylor-series benchmark on a neural ODE #665

Merged: 6 commits, Oct 5, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ docs/benchmarks/vanderpol/*.npy
 docs/benchmarks/lotkavolterra/*.npy
 docs/benchmarks/taylor_pleiades/*.npy
 docs/benchmarks/taylor_fitzhughnagumo/*.npy
+docs/benchmarks/taylor_node/*.npy

 # IDE stuff
 .idea/
63 changes: 25 additions & 38 deletions docs/benchmarks/taylor_fitzhughnagumo/plot.ipynb

Large diffs are not rendered by default.

30 changes: 17 additions & 13 deletions docs/benchmarks/taylor_fitzhughnagumo/plot.md
@@ -38,20 +38,20 @@ def load_results():
 def choose_style(label):
     """Choose a plotting style for a given algorithm."""
     if "doubling" in label.lower():
-        return {"color": "C3", "linestyle": "dotted"}
+        return {"color": "C3", "linestyle": "dotted", "label": label}
     if "unroll" in label.lower():
-        return {"color": "C2", "linestyle": "dashdot"}
+        return {"color": "C2", "linestyle": "dashdot", "label": label}
     if "taylor" in label.lower():
-        return {"color": "C0", "linestyle": "solid"}
+        return {"color": "C0", "linestyle": "solid", "label": label}
     if "forward" in label.lower():
-        return {"color": "C1", "linestyle": "dashed"}
+        return {"color": "C1", "linestyle": "dashed", "label": label}
     msg = f"Label {label} unknown."
     raise ValueError(msg)


 def plot_results(axis_compile, axis_perform, results):
     """Plot the results."""
-    style_curve = {"alpha": 0.85, "markersize": 5}
+    style_curve = {"alpha": 0.85}
     style_area = {"alpha": 0.15}
     for label, wp in results.items():
         style = choose_style(label)
@@ -68,17 +68,17 @@ def plot_results(axis_compile, axis_perform, results):
             work_std = _adaptive_repeat(work_std, num_repeats)
             axis_perform.set_xticks(inputs[::2])

-        axis_compile.semilogy(inputs, work_compile, label=label, **style, **style_curve)
+        axis_compile.semilogy(inputs, work_compile, **style, **style_curve)

         range_lower, range_upper = work_mean - work_std, work_mean + work_std
-        axis_perform.semilogy(inputs, work_mean, label=label, **style, **style_curve)
+        axis_perform.semilogy(inputs, work_mean, **style, **style_curve)
         axis_perform.fill_between(
             inputs, range_lower, range_upper, **style, **style_area
         )

     axis_compile.set_xlim((1, 17))
-    axis_perform.set_yticks((1e-6, 1e-5, 1e-4))
-    axis_perform.set_ylim((7e-7, 1.5e-4))
+    axis_compile.set_ylim((5e-3, 8e1))
+    axis_perform.set_yticks((1e-5, 1e-4))
     return axis_compile, axis_perform
@@ -94,16 +94,16 @@ def _adaptive_repeat(xs, ys):
 plt.rcParams.update(notebook.plot_config())

 fig, (axis_perform, axis_compile) = plt.subplots(
-    ncols=2, dpi=150, figsize=(8, 3), sharex=True, constrained_layout=True
+    ncols=2, figsize=(8, 3), dpi=150, sharex=True
 )
 fig.suptitle("FitzHugh-Nagumo problem, Taylor-series estimation")

 results = load_results()

 axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results)

-axis_compile.set_title("Compile time")
+axis_compile.set_title("Compilation time")
 axis_perform.set_title("Evaluation time")
-axis_perform.legend(loc="lower right")
+axis_compile.legend(loc="lower right")
 axis_compile.set_xlabel("Number of Derivatives")
 axis_perform.set_xlabel("Number of Derivatives")
 axis_perform.set_ylabel("Wall time (sec)")
@@ -112,3 +112,7 @@ axis_compile.grid()

 plt.show()
 ```
+
+```python
+
+```
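
The labelling change above moves `label` into the style dict returned by `choose_style`, so every artist drawn with `**style` carries its label and the legend can live on the compile-time axis alone. A minimal sketch of that pattern, with made-up data:

```python
# Minimal sketch of the label-inside-style pattern (hypothetical data,
# not taken from the benchmark).
import matplotlib.pyplot as plt

style = {"color": "C3", "linestyle": "dotted", "label": "Taylor-mode (doubling)"}

fig, ax = plt.subplots()
ax.semilogy([1, 2, 3], [1e-1, 1e-2, 1e-3], **style)  # the label travels with the style
ax.legend(loc="lower right")
plt.show()
```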
2,696 changes: 2,696 additions & 0 deletions docs/benchmarks/taylor_node/plot.ipynb

Large diffs are not rendered by default.

115 changes: 115 additions & 0 deletions docs/benchmarks/taylor_node/plot.md
@@ -0,0 +1,115 @@
---
jupyter:
  jupytext:
    formats: ipynb,md
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
      jupytext_version: 1.15.0
  kernelspec:
    display_name: Python 3 (ipykernel)
    language: python
    name: python3
---

# Taylor-series: Neural ODE problem

```python
"""Benchmark all Taylor-series estimators on a Neural ODE."""

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.config import config

from probdiffeq.util.doc_util import notebook

config.update("jax_platform_name", "cpu")
```

```python
def load_results():
    """Load the results from a file."""
    return jnp.load("./results.npy", allow_pickle=True)[()]


def choose_style(label):
    """Choose a plotting style for a given algorithm."""
    if "doubling" in label.lower():
        return {"color": "C3", "linestyle": "dotted", "label": label}
    if "unroll" in label.lower():
        return {"color": "C2", "linestyle": "dashdot", "label": label}
    if "taylor" in label.lower():
        return {"color": "C0", "linestyle": "solid", "label": label}
    if "forward" in label.lower():
        return {"color": "C1", "linestyle": "dashed", "label": label}
    msg = f"Label {label} unknown."
    raise ValueError(msg)


def plot_results(axis_compile, axis_perform, results):
    """Plot the results."""
    style_curve = {"alpha": 0.85}
    style_area = {"alpha": 0.15}
    for label, wp in results.items():
        style = choose_style(label)

        inputs = wp["arguments"]
        work_compile = wp["work_compile"]
        work_mean, work_std = wp["work_mean"], wp["work_std"]

        if "doubling" in label:
            # Doubling skips derivative counts, so repeat each measurement
            # until the curve aligns with the dense x-axis of the others.
            num_repeats = jnp.diff(jnp.concatenate((jnp.ones((1,)), inputs)))
            inputs = jnp.arange(1, jnp.amax(inputs))
            work_compile = _adaptive_repeat(work_compile, num_repeats)
            work_mean = _adaptive_repeat(work_mean, num_repeats)
            work_std = _adaptive_repeat(work_std, num_repeats)

        axis_compile.semilogy(inputs, work_compile, **style, **style_curve)

        range_lower, range_upper = work_mean - work_std, work_mean + work_std
        axis_perform.semilogy(inputs, work_mean, **style, **style_curve)
        axis_perform.fill_between(
            inputs, range_lower, range_upper, **style, **style_area
        )

    axis_compile.set_xticks(range(1, 15))
    axis_compile.set_ylim((1e-3, 1e2))
    return axis_compile, axis_perform


def _adaptive_repeat(xs, ys):
    """Repeat the doubling values correctly to create a comprehensible plot."""
    zs = []
    for x, y in zip(xs, ys):
        zs.extend([x] * int(y))
    return jnp.asarray(zs)
```
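
For intuition, `_adaptive_repeat` stretches the per-doubling measurements onto the dense grid of derivative counts. A hypothetical check, not part of the benchmark data:

```python
# Three measurements covering 1, 2, and 4 derivative counts, respectively,
# are repeated so that they align with a dense x-axis (hypothetical numbers).
work = jnp.asarray([0.1, 0.2, 0.4])
repeats = jnp.asarray([1.0, 2.0, 4.0])
print(_adaptive_repeat(work, repeats))  # roughly: [0.1 0.2 0.2 0.4 0.4 0.4 0.4]
```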

```python
plt.rcParams.update(notebook.plot_config())

fig, (axis_perform, axis_compile) = plt.subplots(
    ncols=2, figsize=(8, 3), dpi=150, sharex=True
)

results = load_results()
axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results)

axis_compile.set_title("Compilation time")
axis_perform.set_title("Evaluation time")
axis_compile.legend(loc="lower right")
axis_compile.set_xlabel("Number of Derivatives")
axis_perform.set_xlabel("Number of Derivatives")
axis_perform.set_ylabel("Wall time (sec)")
axis_perform.grid()
axis_compile.grid()


plt.show()
```

```python

```
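
The notebook assumes that `results.npy` stores a dictionary of work-precision records, one per algorithm label, as produced by `run_taylor_node.py` below. A sketch of that structure with made-up numbers:

```python
# Expected structure of the results file (made-up values; the real file is
# written by run_taylor_node.py).
results_example = {
    "Taylor-mode (scan)": {
        "arguments": jnp.asarray([2, 3, 4]),  # number of Taylor coefficients
        "work_compile": jnp.asarray([0.5, 0.9, 1.8]),  # compile time (sec)
        "work_mean": jnp.asarray([1e-4, 2e-4, 4e-4]),  # mean run time (sec)
        "work_std": jnp.asarray([1e-5, 2e-5, 4e-5]),  # run-time std. dev. (sec)
    },
}
```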
193 changes: 193 additions & 0 deletions docs/benchmarks/taylor_node/run_taylor_node.py
@@ -0,0 +1,193 @@
"""Benchmark the initialisation methods on a Neural ODE.

See makefile for instructions.
"""
import argparse
import functools
import os
import statistics
import time
import timeit
from typing import Callable

import jax
import jax.numpy as jnp
from diffeqzoo import backend
from jax import config

from probdiffeq.impl import impl
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.util.doc_util import info


def set_jax_config() -> None:
    """Set JAX and other external libraries up."""
    # x64 precision
    config.update("jax_enable_x64", True)

    # CPU
    config.update("jax_platform_name", "cpu")


def set_probdiffeq_config() -> None:
    """Set probdiffeq up."""
    impl.select("isotropic", ode_shape=(14,))


def print_library_info() -> None:
    """Print the environment info for this benchmark."""
    info.print_info()
    print("\n------------------------------------------\n")


def parse_arguments() -> argparse.Namespace:
    """Parse the arguments from the command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_time", type=float)
    parser.add_argument("--repeats", type=int, default=3)
    parser.add_argument("--save", action=argparse.BooleanOptionalAction)
    return parser.parse_args()


def timeit_fun_from_args(arguments: argparse.Namespace, /) -> Callable:
    """Construct a timeit-function from the command-line arguments."""

    def timer(fun, /):
        return list(timeit.repeat(fun, number=1, repeat=arguments.repeats))

    return timer


def taylor_mode() -> Callable:
    """Taylor-mode estimation (scan-based)."""
    vf_auto, (u0,) = _node()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num)
        return jax.block_until_ready(tcoeffs)

    return estimate


def taylor_mode_unroll() -> Callable:
    """Taylor-mode estimation (unrolled)."""
    vf_auto, (u0,) = _node()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        tcoeffs = autodiff.taylor_mode_unroll(vf_auto, (u0,), num=num)
        return jax.block_until_ready(tcoeffs)

    return estimate


def taylor_mode_doubling() -> Callable:
    """Taylor-mode estimation (doubling)."""
    vf_auto, (u0,) = _node()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num_doublings=num)
        return jax.block_until_ready(tcoeffs)

    return estimate


def forward_mode() -> Callable:
    """Forward-mode estimation."""
    vf_auto, (u0,) = _node()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        tcoeffs = autodiff.forward_mode(vf_auto, (u0,), num=num)
        return jax.block_until_ready(tcoeffs)

    return estimate


def _node():
    """Construct a neural-ODE vector field and an initial value."""
    N = 100
    M = 100
    num_layers = 2

    key = jax.random.PRNGKey(seed=1)
    key1, key2, key3, key4 = jax.random.split(key, num=4)

    u0 = jax.random.uniform(key1, shape=(N,))

    weights = jax.random.normal(key2, shape=(num_layers, 2, M, N))
    biases1 = jax.random.normal(key3, shape=(num_layers, M))
    biases2 = jax.random.normal(key4, shape=(num_layers, N))

    fun = jnp.tanh

    @jax.jit
    def vf(x):
        # Each block maps x -> tanh(w2.T @ tanh(w1 @ x + b1) + b2);
        # num_layers such blocks are applied in sequence.
        for (w1, w2), b1, b2 in zip(weights, biases1, biases2):
            x = fun(w2.T @ fun(w1 @ x + b1) + b2)
        return x

    return vf, (u0,)


def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
    """Call repeatedly until a time-threshold is exceeded."""
    work_compile = []
    work_mean = []
    work_std = []
    arguments = []

    t0 = time.perf_counter()
    arg = 1
    while (elapsed := time.perf_counter() - t0) < max_time:
        print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time)
        t0 = time.perf_counter()
        tcoeffs = fun(arg)
        t1 = time.perf_counter()
        time_compile = t1 - t0

        time_execute = timeit_fun(lambda: fun(arg))  # noqa: B023

        arguments.append(len(tcoeffs))
        work_compile.append(time_compile)
        work_mean.append(statistics.mean(time_execute))
        work_std.append(statistics.stdev(time_execute))
        arg += 1
    print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time)
    return {
        "work_mean": jnp.asarray(work_mean),
        "work_std": jnp.asarray(work_std),
        "work_compile": jnp.asarray(work_compile),
        "arguments": jnp.asarray(arguments),
    }


if __name__ == "__main__":
set_jax_config()
backend.select("jax")
algorithms = {
r"Forward-mode": forward_mode(),
r"Taylor-mode (scan)": taylor_mode(),
r"Taylor-mode (unroll)": taylor_mode_unroll(),
r"Taylor-mode (doubling)": taylor_mode_doubling(),
}

# Compute a reference solution
args = parse_arguments()
timeit_fun = timeit_fun_from_args(args)

# Compute all work-precision diagrams
results = {}
for label, algo in algorithms.items():
print("\n")
print(label)
results[label] = adaptive_benchmark(
algo, timeit_fun=timeit_fun, max_time=args.max_time
)
# Save results
if args.save:
jnp.save(os.path.dirname(__file__) + "/results.npy", results)
print("\nSaving successful.\n")
else:
print("\nSkipped saving.\n")