Skip to content

Commit

Permalink
feat(plots): Also remove ticks when labels=False in train/test plots
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Feb 18, 2024
1 parent e8a75c2 commit fa382fb
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions src/gen_experiments/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import scipy
import seaborn as sns
from matplotlib.axes import Axes

PAL = sns.color_palette("Set1")
PLOT_KWS = {"alpha": 0.7, "linewidth": 3}
Expand Down Expand Up @@ -128,7 +129,7 @@ def signed_sqrt(x):


def plot_training_trajectory(
ax: plt.Axes,
ax: Axes,
x_train: np.ndarray,
x_true: np.ndarray,
x_smooth: np.ndarray,
Expand Down Expand Up @@ -156,6 +157,8 @@ def plot_training_trajectory(
)
if labels:
ax.set(xlabel="$x_0$", ylabel="$x_1$")
else:
ax.set(xticks=[], yticks=[])
elif x_train.shape[1] == 3:
ax.plot(
x_true[:, 0],
Expand Down Expand Up @@ -187,6 +190,8 @@ def plot_training_trajectory(
)
if labels:
ax.set(xlabel="$x$", ylabel="$y$", zlabel="$z$")
else:
ax.set(xticks=[], yticks=[], zticks=[])
else:
raise ValueError("Can only plot 2d or 3d data.")

Expand Down Expand Up @@ -226,7 +231,7 @@ def plot_pde_training_data(last_train, last_train_true, smoothed_last_train):


def plot_test_sim_data_1d_panel(
axs: Sequence[plt.Axes],
axs: Sequence[Axes],
x_test: np.ndarray,
x_sim: np.ndarray,
t_test: np.ndarray,
Expand All @@ -240,31 +245,33 @@ def plot_test_sim_data_1d_panel(


def _plot_test_sim_data_2d(
axs: Annotated[Sequence[plt.Axes], "len=2"],
axs: Annotated[Sequence[Axes], "len=2"],
x_test: np.ndarray,
x_sim: np.ndarray,
labels: bool = True,
) -> None:
axs[0].plot(x_test[:, 0], x_test[:, 1], "k", label="True Trajectory")
if labels:
axs[0].set(xlabel="$x_0$", ylabel="$x_1$")
axs[1].plot(x_sim[:, 0], x_sim[:, 1], "r--", label="Simulation")
if labels:
axs[1].set(xlabel="$x_0$", ylabel="$x_1$")
for ax in axs:
if labels:
ax.set(xlabel="$x_0$", ylabel="$x_1$")
else:
ax.set(xticks=[], yticks=[])


def _plot_test_sim_data_3d(
axs: Annotated[Sequence[plt.Axes], "len=3"],
axs: Annotated[Sequence[Axes], "len=3"],
x_test: np.ndarray,
x_sim: np.ndarray,
labels: bool = True,
) -> None:
axs[0].plot(x_test[:, 0], x_test[:, 1], x_test[:, 2], "k", label="True Trajectory")
if labels:
axs[0].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$")
axs[1].plot(x_sim[:, 0], x_sim[:, 1], x_sim[:, 2], "r--", label="Simulation")
if labels:
axs[1].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$")
for ax in axs:
if labels:
ax.set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$")
else:
ax.set(xticks=[], yticks=[], zticks=[])


def plot_test_trajectories(
Expand Down

0 comments on commit fa382fb

Please sign in to comment.