Skip to content

Commit

Permalink
feat (data): Add plotting to data generation experiment (default off)
Browse files Browse the repository at this point in the history
Change the helper plot_training_trajectory to a private function,
and make x_smooth optional in plot_training_data.

This function is used ODE fit-eval to compare how well smoothing works,
but also in data generation just to get a sense of noise (where there is
no smoothed trajectory to plot).
  • Loading branch information
Jacob-Stevens-Haas committed Dec 6, 2024
1 parent 346a56c commit e8add63
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
6 changes: 6 additions & 0 deletions src/gen_experiments/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .gridsearch.typing import GridsearchResultDetails
from .odes import ode_setup
from .pdes import pde_setup
from .plotting import plot_training_data
from .typing import Float1D, Float2D, ProbData

INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
Expand All @@ -26,6 +27,7 @@ def gen_data(
noise_rel: Optional[float] = None,
dt: float = 0.01,
t_end: float = 10,
display: bool = False,
) -> dict[str, Any]:
"""Generate random training and test data
Expand All @@ -47,6 +49,7 @@ def gen_data(
None.
dt: time step for sample
t_end: end time of simulation
display: Whether to display graphics of generated data.
Returns:
dictionary of data and descriptive information
Expand Down Expand Up @@ -80,6 +83,9 @@ def gen_data(
dt=dt,
t_end=t_end,
)
if display:
fig = plot_training_data(x_train[0], x_train_true[0])
fig.suptitle("Sample Trajectory")
return {
"data": ProbData(
dt,
Expand Down
20 changes: 14 additions & 6 deletions src/gen_experiments/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ def signed_sqrt(x):
fig.tight_layout()


def plot_training_trajectory(
def _plot_training_trajectory(
ax: Axes,
x_train: np.ndarray,
x_true: np.ndarray,
x_smooth: np.ndarray,
x_smooth: np.ndarray | None,
labels: bool = True,
) -> None:
"""Plot a single training trajectory"""
Expand All @@ -141,7 +141,10 @@ def plot_training_trajectory(
color=PAL[1],
**PLOT_KWS,
)
if np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12:
if (
x_smooth is not None
and np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12
):
ax.plot(
x_smooth[:, 0],
x_smooth[:, 1],
Expand Down Expand Up @@ -173,7 +176,10 @@ def plot_training_trajectory(
label="Measured values",
alpha=0.3,
)
if np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12:
if (
x_smooth is not None
and np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12
):
ax.plot(
x_smooth[:, 0],
x_smooth[:, 1],
Expand All @@ -191,7 +197,9 @@ def plot_training_trajectory(
raise ValueError("Can only plot 2d or 3d data.")


def plot_training_data(x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.ndarray):
def plot_training_data(
x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.ndarray | None = None
):
"""Plot training data (and smoothed training data, if different)."""
fig = plt.figure(figsize=(12, 6))
if x_train.shape[-1] == 2:
Expand All @@ -200,7 +208,7 @@ def plot_training_data(x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.nda
ax0 = fig.add_subplot(1, 2, 1, projection="3d")
else:
raise ValueError("Too many or too few coordinates to plot")
plot_training_trajectory(ax0, x_train, x_true, x_smooth)
_plot_training_trajectory(ax0, x_train, x_true, x_smooth)
ax0.legend()
ax0.set(title="Training data")
ax1 = fig.add_subplot(1, 2, 2)
Expand Down

0 comments on commit e8add63

Please sign in to comment.