Skip to content

Commit

Permalink
Add quickstart
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Jun 7, 2024
1 parent 582a00d commit e033f03
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,56 @@

<h4 align="center">A benchmark for Autoregressive PDE Emulators in <a href="https://github.com/google/jax" target="_blank">JAX</a>.</h4>

### Quickstart

Train a ConvNet to emulate 1D advection, display train loss, test error metric
rollout, and a sample rollout.

```python
import apebench
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

advection_scenario = apebench.scenarios.difficulty.Advection()

data, trained_nets = advection_scenario(
task_config="predict",
network_config="Conv;26;10;relu",
train_config="one",
num_seeds=3,
)

data_loss = apebench.melt_loss(data)
data_metrics = apebench.melt_metrics(data)
data_sample_rollout = apebench.melt_sample_rollouts(data)

fig, axs = plt.subplots(3, 1, figsize=(6, 12))

sns.lineplot(data_loss, x="update_step", y="train_loss", ax=axs[0])
axs[0].set_yscale("log")
axs[0].set_title("Training loss")

sns.lineplot(data_metrics, x="time_step", y="mean_nRMSE", ax=axs[1])
axs[1].set_ylim(-0.05, 1.05)
axs[1].set_title("Metric rollout")

axs[2].imshow(
np.array(data_sample_rollout["sample_rollout"][0])[:, 0, :].T,
origin="lower",
aspect="auto",
vmin=-1,
vmax=1,
cmap="RdBu_r",
)
axs[2].set_xlabel("time")
axs[2].set_ylabel("space")
axs[2].set_title("Sample rollout")

plt.show()
```

### More details

Autoregressive neural emulators can be used to efficiently forecast transient
phenomena, often associated with differential equations. Denote by
Expand Down

0 comments on commit e033f03

Please sign in to comment.