Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify taylor_mode_doubling implementation and tweak FHN benchmark #657

Merged
merged 9 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions docs/benchmarks/taylor_fitzhughnagumo/plot.ipynb

Large diffs are not rendered by default.

45 changes: 32 additions & 13 deletions docs/benchmarks/taylor_fitzhughnagumo/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,50 @@ def choose_style(label):

def plot_results(axis_compile, axis_perform, results):
"""Plot the results."""
style_curve = {"alpha": 0.85, "markersize": 5}
style_area = {"alpha": 0.15}
for label, wp in results.items():
style = choose_style(label)

inputs = wp["arguments"]
work_mean = wp["work_compile"]
axis_compile.semilogy(inputs, work_mean, label=label, **style)

work_compile = wp["work_compile"]
work_mean, work_std = wp["work_mean"], wp["work_std"]
work_min, work_max = (wp["work_min"], wp["work_max"])
range_lower, range_upper = work_mean - work_std, work_mean + work_std
axis_perform.semilogy(inputs, work_mean, label=label, **style)
axis_perform.fill_between(inputs, range_lower, range_upper, alpha=0.3, **style)

if "doubling" in label:
num_repeats = jnp.diff(jnp.concatenate((jnp.ones((1,)), inputs)))
inputs = jnp.arange(1, jnp.amax(inputs) * 1)
work_compile = _adaptive_repeat(work_compile, num_repeats)
work_mean = _adaptive_repeat(work_mean, num_repeats)
work_std = _adaptive_repeat(work_std, num_repeats)
axis_perform.set_xticks(inputs[::2])

axis_compile.semilogy(inputs, work_compile, label=label, **style, **style_curve)

range_lower, range_upper = work_mean - work_std, work_mean + work_std
axis_perform.semilogy(inputs, work_mean, label=label, **style, **style_curve)
axis_perform.fill_between(
inputs, range_lower, range_upper, **style, **style_area
)

axis_compile.set_xlim((1, 17))
axis_perform.set_yticks((1e-6, 1e-5, 1e-4))
axis_perform.set_ylim((7e-7, 1.5e-4))
return axis_compile, axis_perform


def _adaptive_repeat(xs, ys):
"""Repeat the doubling values correctly to create a comprehensible plot."""
zs = []
for x, y in zip(xs, ys):
zs.extend([x] * int(y))
return jnp.asarray(zs)
```

```python
plt.rcParams.update(notebook.plot_config())

fig, (axis_perform, axis_compile) = plt.subplots(
ncols=2, dpi=150, figsize=(8, 3), sharex=True, tight_layout=True
ncols=2, dpi=150, figsize=(8, 3), sharex=True, constrained_layout=True
)
fig.suptitle("FitzHugh-Nagumo problem, Taylor-series estimation")

Expand All @@ -78,7 +101,7 @@ axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results)

axis_compile.set_title("Compile time")
axis_perform.set_title("Evaluation time")
axis_perform.legend()
axis_perform.legend(loc="lower right")
axis_compile.set_xlabel("Number of Derivatives")
axis_perform.set_xlabel("Number of Derivatives")
axis_perform.set_ylabel("Wall time (sec)")
Expand All @@ -87,7 +110,3 @@ axis_compile.grid()

plt.show()
```

```python

```
Binary file modified docs/benchmarks/taylor_fitzhughnagumo/results.npy
Binary file not shown.
15 changes: 3 additions & 12 deletions docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def taylor_mode_doubling() -> Callable:

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num=num)
tcoeffs = autodiff.taylor_mode_doubling(vf_auto, (u0,), num_doublings=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -111,37 +111,28 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
work_compile = []
work_mean = []
work_std = []
work_min = []
work_median = []
work_max = []
arguments = []

t0 = time.perf_counter()
arg = 1
while (elapsed := time.perf_counter() - t0) < max_time:
print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time)
t0 = time.perf_counter()
_ = fun(arg)
tcoeffs = fun(arg)
t1 = time.perf_counter()
time_compile = t1 - t0

time_execute = timeit_fun(lambda: fun(arg)) # noqa: B023

arguments.append(arg + 1) # plus one, because second-order problem
arguments.append(len(tcoeffs))
work_compile.append(time_compile)
work_mean.append(statistics.mean(time_execute))
work_std.append(statistics.stdev(time_execute))
work_min.append(min(time_execute))
work_median.append(statistics.median(time_execute))
work_max.append(max(time_execute))
arg += 1
print("num =", arg, "| elapsed =", elapsed, "| max_time =", max_time)
return {
"work_median": jnp.asarray(work_median),
"work_mean": jnp.asarray(work_mean),
"work_std": jnp.asarray(work_std),
"work_min": jnp.asarray(work_min),
"work_max": jnp.asarray(work_max),
"work_compile": jnp.asarray(work_compile),
"arguments": jnp.asarray(arguments),
}
Expand Down
26 changes: 13 additions & 13 deletions docs/benchmarks/taylor_pleiades/plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
"id": "daa53abe",
"metadata": {
"execution": {
"iopub.execute_input": "2023-10-03T13:09:38.241381Z",
"iopub.status.busy": "2023-10-03T13:09:38.241063Z",
"iopub.status.idle": "2023-10-03T13:09:38.951898Z",
"shell.execute_reply": "2023-10-03T13:09:38.951250Z"
"iopub.execute_input": "2023-10-04T07:21:45.767181Z",
"iopub.status.busy": "2023-10-04T07:21:45.766375Z",
"iopub.status.idle": "2023-10-04T07:21:46.659606Z",
"shell.execute_reply": "2023-10-04T07:21:46.658861Z"
}
},
"outputs": [],
Expand All @@ -41,10 +41,10 @@
"id": "6ae70671",
"metadata": {
"execution": {
"iopub.execute_input": "2023-10-03T13:09:38.954963Z",
"iopub.status.busy": "2023-10-03T13:09:38.954692Z",
"iopub.status.idle": "2023-10-03T13:09:38.960305Z",
"shell.execute_reply": "2023-10-03T13:09:38.959647Z"
"iopub.execute_input": "2023-10-04T07:21:46.662850Z",
"iopub.status.busy": "2023-10-04T07:21:46.662555Z",
"iopub.status.idle": "2023-10-04T07:21:46.669331Z",
"shell.execute_reply": "2023-10-04T07:21:46.668624Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -87,10 +87,10 @@
"id": "c6df444c-a339-40bd-ba62-6dda18b50ad6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-10-03T13:09:38.963166Z",
"iopub.status.busy": "2023-10-03T13:09:38.962907Z",
"iopub.status.idle": "2023-10-03T13:09:40.045928Z",
"shell.execute_reply": "2023-10-03T13:09:40.045188Z"
"iopub.execute_input": "2023-10-04T07:21:46.672300Z",
"iopub.status.busy": "2023-10-04T07:21:46.672049Z",
"iopub.status.idle": "2023-10-04T07:21:47.926884Z",
"shell.execute_reply": "2023-10-04T07:21:47.926198Z"
}
},
"outputs": [
Expand All @@ -99,7 +99,7 @@
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1696338579.002870 393629 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n"
"I0000 00:00:1696404106.717187 449810 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n"
]
},
{
Expand Down
15 changes: 14 additions & 1 deletion docs/dev_docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
# Change log

## v0.2.1

**Breaking changes:**

* The input-argument to `taylor_mode_doubling` is `num_doublings` instead of `num`.
This argument behaves differently to e.g., `taylor_mode(..., num)`.
Since doubling is experimental, the version is not bumped to `v0.3.0`.


**New features:**

**Notable bug-fixes:**

## v0.2.0

This version overhauls large parts of the API.
Consider the quickstart for an introduction about the "new" way of doing things.
From now on, this change log will be used properly.

Notable bug fixes:
**Notable bug-fixes:**

* The log-pdf behaviour of Gaussian random variables has been corrected (previously, the returned values were slightly incorrect).
This means that the behaviour of, e.g., parameter estimation scripts will change slightly.
Expand Down
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ example:
jupytext --quiet --sync docs/examples_parameter_estimation/*

run-benchmarks:
time python docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py --max_time 20 --repeats 15 --save
time python docs/benchmarks/taylor_fitzhughnagumo/run_taylor_fitzhughnagumo.py --max_time 15 --repeats 15 --save
jupytext --quiet --sync docs/benchmarks/taylor_fitzhughnagumo/*.ipynb
jupytext --quiet --execute docs/benchmarks/taylor_fitzhughnagumo/*.ipynb
time python docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py --max_time 15 --repeats 5 --save
Expand Down
4 changes: 4 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ plugins:
- benchmarks/taylor_pleiades/*.md
- benchmarks/taylor_pleiades/*.py
- benchmarks/taylor_pleiades/.ipynb_checkpoints/*
- benchmarks/taylor_fitzhughnagumo/*.md
- benchmarks/taylor_fitzhughnagumo/*.py
- benchmarks/taylor_fitzhughnagumo/.ipynb_checkpoints/*
- mkdocs-jupyter
extra:
social:
Expand Down Expand Up @@ -160,3 +163,4 @@ nav:
- benchmarks/pleiades/plot.ipynb
- benchmarks/hires/plot.ipynb
- benchmarks/taylor_pleiades/plot.ipynb
- benchmarks/taylor_fitzhughnagumo/plot.ipynb
27 changes: 13 additions & 14 deletions probdiffeq/solvers/taylor/autodiff.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
r"""Taylor-expand the solution of an initial value problem (IVP)."""

import functools
import itertools
from typing import Callable, Tuple

import jax
import jax.experimental.jet
import jax.experimental.ode
import jax.numpy as jnp

# TODO: split into subpackage


@functools.partial(jax.jit, static_argnums=[0], static_argnames=["num"])
def taylor_mode(vf: Callable, initial_values: Tuple, /, num: int):
"""Taylor-expand the solution of an IVP with Taylor-mode differentiation."""
# Number of positional arguments in f
Expand Down Expand Up @@ -76,7 +73,6 @@ def mask(i):
return [x[mask(k) : mask(k + 1 - n)] for k in range(n)]


@functools.partial(jax.jit, static_argnums=[0], static_argnames=["num"])
def forward_mode(vf: Callable, initial_values: Tuple, /, num: int):
"""Taylor-expand the solution of an IVP with forward-mode differentiation.

Expand Down Expand Up @@ -108,8 +104,7 @@ def df(*args):
return jax.tree_util.Partial(df)


@functools.partial(jax.jit, static_argnums=[0], static_argnames=["num"])
def taylor_mode_doubling(vf: Callable, initial_values: Tuple, /, num: int):
def taylor_mode_doubling(vf: Callable, initial_values: Tuple, /, num_doublings: int):
"""Combine Taylor-mode differentiation and Newton's doubling.

!!! warning "Warning: highly EXPERIMENTAL feature!"
Expand All @@ -119,7 +114,7 @@ def taylor_mode_doubling(vf: Callable, initial_values: Tuple, /, num: int):
and without any deprecation policy.

!!! warning "Compilation time"
JIT-compiling this function unrolls a loop of length `num`.
JIT-compiling this function unrolls a loop.

"""
(u0,) = initial_values
Expand All @@ -142,15 +137,17 @@ def jet_embedded(*c, degree):
return _normalise(p_new, *s_new)

taylor_coefficients = [u0]
while (deg := len(taylor_coefficients)) < num + 1:
degrees = list(itertools.accumulate(map(lambda s: 2**s, range(num_doublings))))
for deg in degrees:
jet_embedded_deg = jax.tree_util.Partial(jet_embedded, degree=deg)
fx, jvp = jax.linearize(jet_embedded_deg, *taylor_coefficients)

# Compute the next set of coefficients.
# TODO: can we jax.fori_loop() this loop?
# the running variable (cs_padded) should have constant size
cs = [(fx[deg - 1] / deg)]
for k in range(deg, min(2 * deg, num)):
cs_padded = cs + [zeros] * (deg - 1)
for i, fx_i in enumerate(fx[deg : 2 * deg]):
# The Jacobian of the embedded jet is block-banded,
# i.e., of the form (for j=3)
# (A0, 0, 0; A1, A0, 0; A2, A1, A0; *, *, *; *, *, *; *, *, *)
Expand All @@ -161,12 +158,14 @@ def jet_embedded(*c, degree):
# Bettencourt et al. (2019;
# "Taylor-mode autodiff for higher-order derivatives in JAX")
# explain details.
cs_padded = cs + [zeros] * (2 * deg - k - 1)
linear_combination = jvp(*cs_padded)[k - deg]
cs += [(fx[k] + linear_combination) / (k + 1)]
# i = k - deg
linear_combination = jvp(*cs_padded)[i]
cs_ = cs_padded[: (i + 1)]
cs_ += [(fx_i + linear_combination) / (i + deg + 1)]
cs_padded = cs_ + [zeros] * (deg - i - 2)

# Store all new coefficients
taylor_coefficients.extend(cs)
taylor_coefficients.extend(cs_padded)

return _unnormalise(*taylor_coefficients)

Expand Down
19 changes: 13 additions & 6 deletions tests/test_solvers/test_taylor/test_exact_first_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ def case_taylor_mode():
return autodiff.taylor_mode


@testing.case()
def case_taylor_mode_doubling():
return autodiff.taylor_mode_doubling


@testing.fixture(name="pb_with_solution")
def fixture_pb_with_solution():
f, u0, (t0, _), f_args = diffeqzoo.ivps.three_body_restricted_first_order()
Expand All @@ -36,10 +31,22 @@ def vf(u, /):


@testing.parametrize_with_cases("taylor_fun", cases=".", prefix="case_")
@testing.parametrize("num", [1, 3])
@testing.parametrize("num", [1, 6])
def test_approximation_identical_to_reference(pb_with_solution, taylor_fun, num):
(f, init), solution = pb_with_solution

derivatives = taylor_fun(f, init, num=num)
assert len(derivatives) == num + 1
for dy, dy_ref in zip(derivatives, solution):
assert jnp.allclose(dy, dy_ref)


@testing.parametrize("num_doublings", [1, 2])
def test_approximation_identical_to_reference_doubling(pb_with_solution, num_doublings):
"""Separately test the doubling-function, because its API is different."""
(f, init), solution = pb_with_solution

derivatives = autodiff.taylor_mode_doubling(f, init, num_doublings=num_doublings)
assert len(derivatives) == jnp.sum(2 ** jnp.arange(num_doublings + 1))
for dy, dy_ref in zip(derivatives, solution):
assert jnp.allclose(dy, dy_ref)