Skip to content

Commit

Permalink
Update other plots in benchmarks to be consistent with Hires (#671)
Browse files Browse the repository at this point in the history
* compute plotting solution in other ivps

* Updated all plots

* Background color in vdp legend
  • Loading branch information
pnkraemer authored Oct 27, 2023
1 parent 2c46215 commit 0ed1316
Show file tree
Hide file tree
Showing 10 changed files with 424 additions and 96 deletions.
81 changes: 49 additions & 32 deletions docs/benchmarks/lotkavolterra/plot.ipynb

Large diffs are not rendered by default.

46 changes: 42 additions & 4 deletions docs/benchmarks/lotkavolterra/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def load_results():
return jnp.load("./results.npy", allow_pickle=True)[()]


def load_solution():
"""Load the solution-to-be-plotted from a file."""
ts = jnp.load("./plot_ts.npy")
ys = jnp.load("./plot_ys.npy")
return ts, ys


def choose_style(label):
"""Choose a plotting style for a given algorithm."""
if "ProbDiffEq" in label:
Expand All @@ -49,6 +56,7 @@ def choose_style(label):

def plot_results(axis, results):
"""Plot the results."""
axis.set_title("Benchmark")
for label, wp in results.items():
style = choose_style(label)

Expand All @@ -62,18 +70,48 @@ def plot_results(axis, results):
axis.set_xlabel("Precision [relative RMSE]")
axis.set_ylabel("Work [wall time, s]")
axis.grid()
axis.legend(loc="upper center", ncols=3, mode="expand", facecolor="ghostwhite")
axis.set_ylim((1e-5, 1e1))
return axis


def plot_solution(axis, ts, ys, yscale="linear"):
axis.set_title("Lotka-Volterra")
kwargs = {"color": "black", "alpha": 0.85}

axis.plot(
ts, ys[:, 0], linestyle="solid", marker="None", label="Predators", **kwargs
)
axis.plot(ts, ys[:, 1], linestyle="dashed", marker="None", label="Prey", **kwargs)
for y in ys.T:
axis.plot(ts[0], y[0], linestyle="None", marker=".", markersize=4, **kwargs)
axis.plot(ts[-1], y[-1], linestyle="None", marker=".", markersize=4, **kwargs)

axis.set_ylim((-1, 27))
axis.legend(facecolor="ghostwhite", ncols=2, loc="lower center", mode="expand")

axis.set_xlabel("Time $t$")
axis.set_ylabel("Solution $y$")
axis.set_yscale(yscale)
return axis
```

```python
plt.rcParams.update(notebook.plot_config())

fig, axis = plt.subplots(dpi=150, constrained_layout=True)
fig.suptitle("Lotka-Volterra problem, terminal-value simulation")
layout = [
["benchmark", "benchmark", "solution"],
["benchmark", "benchmark", "solution"],
]
fig, axes = plt.subplot_mosaic(layout, figsize=(8, 3), constrained_layout=True, dpi=300)


results = load_results()
axis = plot_results(axis, results)
axis.legend(loc="center left", bbox_to_anchor=(1, 0.5))
ts, ys = load_solution()

_ = plot_results(axes["benchmark"], results)
_ = plot_solution(axes["solution"], ts, ys)

plt.show()
```

Expand Down
29 changes: 29 additions & 0 deletions docs/benchmarks/lotkavolterra/run_lotkavolterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,30 @@ def param_to_solution(tol):
return param_to_solution


def plot_ivp_solution():
"""Compute plotting-values for the IVP."""

def vf_scipy(_t, y):
"""Lotka--Volterra dynamics."""
dy1 = 0.5 * y[0] - 0.05 * y[0] * y[1]
dy2 = -0.5 * y[1] + 0.05 * y[0] * y[1]
return np.asarray([dy1, dy2])

u0 = jnp.asarray((20.0, 20.0))
time_span = np.asarray([0.0, 50.0])

tol = 1e-12
solution = scipy.integrate.solve_ivp(
vf_scipy,
y0=u0,
t_span=time_span,
atol=1e-3 * tol,
rtol=tol,
method="LSODA",
)
return solution.t, solution.y.T


def rmse_relative(expected: jax.Array, *, nugget=1e-5) -> Callable:
"""Compute the relative RMSE."""
expected = jnp.asarray(expected)
Expand Down Expand Up @@ -218,6 +242,9 @@ def parameter_list_to_workprecision(list_of_args, /):
set_jax_config()
print_library_info()

# Simulate once to get plotting code
ts, ys = plot_ivp_solution()

# If we change the probdiffeq-impl halfway through a script, a warning is raised.
# But for this benchmark, such a change is on purpose.
warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -255,6 +282,8 @@ def parameter_list_to_workprecision(list_of_args, /):
# Save results
if args.save:
jnp.save(os.path.dirname(__file__) + "/results.npy", results)
jnp.save(os.path.dirname(__file__) + "/plot_ts.npy", ts)
jnp.save(os.path.dirname(__file__) + "/plot_ys.npy", ys)
print("\nSaving successful.\n")
else:
print("\nSkipped saving.\n")
84 changes: 52 additions & 32 deletions docs/benchmarks/pleiades/plot.ipynb

Large diffs are not rendered by default.

49 changes: 45 additions & 4 deletions docs/benchmarks/pleiades/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def load_results():
return jnp.load("./results.npy", allow_pickle=True)[()]


def load_solution():
"""Load the solution-to-be-plotted from a file."""
ts = jnp.load("./plot_ts.npy")
ys = jnp.load("./plot_ys.npy")
return ts, ys


def choose_style(label):
"""Choose a plotting style for a given algorithm."""
if "probdiffeq" in label.lower():
Expand All @@ -51,6 +58,7 @@ def choose_style(label):

def plot_results(axis, results):
"""Plot the results."""
axis.set_title("Benchmark")
for label, wp in results.items():
style = choose_style(label)

Expand All @@ -64,18 +72,51 @@ def plot_results(axis, results):
axis.set_xlabel("Precision [absolute RMSE]")
axis.set_ylabel("Work [wall time, s]")
axis.grid()
axis.legend(
loc="upper center",
ncols=4,
fontsize="x-small",
mode="expand",
facecolor="ghostwhite",
)
axis.set_ylim((1e-3, 2e0))
return axis


def plot_solution(axis, ts, ys, yscale="linear"):
axis.set_title("Pleiades")
kwargs = {"color": "goldenrod", "alpha": 0.85}

axis.plot(ys[:, :7], ys[:, 7:14], linestyle="solid", marker="None", **kwargs)
axis.plot(
ys[0, :7], ys[0, 7:14], linestyle="None", marker=".", markersize=4, **kwargs
)
axis.plot(
ys[-1, :7], ys[-1, 7:14], linestyle="None", marker="*", markersize=8, **kwargs
)

axis.set_xlabel("Time $t$")
axis.set_ylabel("Solution $y$")
axis.set_yscale(yscale)
return axis
```

```python
plt.rcParams.update(notebook.plot_config())

fig, axis = plt.subplots(dpi=150)
fig.suptitle("Pleiades problem, terminal-value simulation")
layout = [
["solution", "benchmark", "benchmark"],
["solution", "benchmark", "benchmark"],
]
fig, axes = plt.subplot_mosaic(layout, figsize=(8, 3), constrained_layout=True, dpi=300)


results = load_results()
axis = plot_results(axis, results)
axis.legend(loc="center left", bbox_to_anchor=(1, 0.5))
ts, ys = load_solution()

_ = plot_results(axes["benchmark"], results)
_ = plot_solution(axes["solution"], ts, ys)

plt.show()
```

Expand Down
49 changes: 49 additions & 0 deletions docs/benchmarks/pleiades/run_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,50 @@ def param_to_solution(tol):
return param_to_solution


def plot_ivp_solution():
"""Compute plotting-values for the IVP."""
# fmt: off
u0 = np.asarray(
[
3.0, 3.0, -1.0, -3.00, 2.0, -2.00, 2.0,
3.0, -3.0, 2.0, 0.00, 0.0, -4.00, 4.0,
0.0, 0.0, 0.0, 0.00, 0.0, 1.75, -1.5,
0.0, 0.0, 0.0, -1.25, 1.0, 0.00, 0.0,
]
)
# fmt: on

def vf_scipy(_t, u):
"""Pleiades problem."""
x = u[0:7] # x
y = u[7:14] # y
xi, xj = x[:, None], x[None, :]
yi, yj = y[:, None], y[None, :]
rij = ((xi - xj) ** 2 + (yi - yj) ** 2) ** (3 / 2)
mj = np.arange(1, 8)[None, :]

# Explicitly avoid dividing by zero for scipy's solver
# The JAX solvers divide by zero and turn the NaNs to zeros.
rij = np.where(rij == 0.0, 1.0, rij)
ddx = np.sum((mj * (xj - xi) / rij), axis=1)
ddy = np.sum((mj * (yj - yi) / rij), axis=1)
return np.concatenate((u[14:21], u[21:28], ddx, ddy))

time_span = np.asarray([0.0, 3.0])

tol = 1e-12
solution = scipy.integrate.solve_ivp(
vf_scipy,
y0=u0,
t_span=time_span,
atol=1e-3 * tol,
rtol=tol,
method="LSODA",
)

return solution.t, solution.y.T


def rmse_absolute(expected: jax.Array) -> Callable:
"""Compute the relative RMSE."""
expected = jnp.asarray(expected)
Expand Down Expand Up @@ -278,6 +322,9 @@ def parameter_list_to_workprecision(list_of_args, /):
set_probdiffeq_config()
print_library_info()

# Simulate once to get plotting code
ts, ys = plot_ivp_solution()

# Read configuration from command line
args = parse_arguments()
tolerances = tolerances_from_args(args)
Expand Down Expand Up @@ -312,6 +359,8 @@ def parameter_list_to_workprecision(list_of_args, /):
# Save results
if args.save:
jnp.save(os.path.dirname(__file__) + "/results.npy", results)
jnp.save(os.path.dirname(__file__) + "/plot_ts.npy", ts)
jnp.save(os.path.dirname(__file__) + "/plot_ys.npy", ys)
print("\nSaving successful.\n")
else:
print("\nSkipped saving.\n")
90 changes: 72 additions & 18 deletions docs/benchmarks/vanderpol/plot.ipynb

Large diffs are not rendered by default.

62 changes: 58 additions & 4 deletions docs/benchmarks/vanderpol/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def load_results():
return jnp.load("./results.npy", allow_pickle=True)[()]


def load_solution():
"""Load the solution-to-be-plotted from a file."""
ts = jnp.load("./plot_ts.npy")
ys = jnp.load("./plot_ys.npy")
return ts, ys


def choose_style(label):
"""Choose a plotting style for a given algorithm."""
if "probdiffeq" in label.lower():
Expand All @@ -51,6 +58,7 @@ def choose_style(label):

def plot_results(axis, results):
"""Plot the results."""
axis.set_title("Benchmark")
for label, wp in results.items():
style = choose_style(label)

Expand All @@ -63,19 +71,65 @@ def plot_results(axis, results):

axis.set_xlabel("Precision [absolute RMSE]")
axis.set_ylabel("Work [wall time, s]")
axis.legend(
loc="upper center",
ncols=3,
fontsize="x-small",
mode="expand",
facecolor="ghostwhite",
)
axis.grid()
axis.set_ylim((1e-3, 3e1))
return axis


def plot_solution(axis, ts, ys, yscale="linear"):
axis.set_title("Van-der-Pol (stiffness: $10^5$)")
kwargs = {"alpha": 0.85}

axis.plot(
ts,
ys[:, 0],
label="y",
linestyle="solid",
color="black",
marker="None",
**kwargs,
)
axis.plot(
ts,
ys[:, 1],
label="$\dot y$",
linestyle="dashed",
color="black",
marker="None",
**kwargs,
)

axis.legend(facecolor="ghostwhite")
axis.set_xlabel("Time $t$")
axis.set_ylabel("Solution $y$ [clipped]")
axis.set_yscale(yscale)
axis.set_ylim((-6, 6))
return axis
```

```python
plt.rcParams.update(notebook.plot_config())

fig, axis = plt.subplots(dpi=150)
fig.suptitle("Van-der-Pol problem (stiffness: $10^5$)")
layout = [
["solution", "benchmark", "benchmark"],
["solution", "benchmark", "benchmark"],
]
fig, axes = plt.subplot_mosaic(layout, figsize=(8, 3), constrained_layout=True, dpi=300)


results = load_results()
axis = plot_results(axis, results)
axis.legend(loc="center left", bbox_to_anchor=(1, 0.5))
ts, ys = load_solution()

_ = plot_results(axes["benchmark"], results)
_ = plot_solution(axes["solution"], ts, ys)

plt.show()
```

Expand Down
Loading

0 comments on commit 0ed1316

Please sign in to comment.