From a5343aa6bdadeb4b93efc7fa577345df75003835 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:56:40 +0000 Subject: [PATCH 01/46] cln: specify generics --- src/gen_experiments/odes.py | 4 +++- src/gen_experiments/utils.py | 11 ++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gen_experiments/odes.py b/src/gen_experiments/odes.py index 0f81159..25bc903 100644 --- a/src/gen_experiments/odes.py +++ b/src/gen_experiments/odes.py @@ -32,7 +32,9 @@ } -def nonlinear_pendulum(t, x, m=1, L=1, g=9.81, forcing=0, return_all=True): +def nonlinear_pendulum( + t, x, m=1, L=1, g=9.81, forcing=0, return_all=True +): # type:ignore """Simple pendulum equation of motion Arguments: diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 69f6cb8..b631766 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -6,6 +6,7 @@ from types import EllipsisType as ellipsis from typing import ( Annotated, + Any, Callable, Collection, Literal, @@ -26,9 +27,9 @@ import scipy import seaborn as sns import sklearn +from matplotlib.axes._axes import Axes from matplotlib.figure import Figure from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec, SubplotSpec -from matplotlib.pyplot import Axes from numpy.typing import DTypeLike, NDArray INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12} @@ -491,8 +492,8 @@ def integration_metrics(model, x_test, t_train, x_dot_test): def unionize_coeff_matrices( - model: ps.SINDy, coeff_true: list[dict] -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + model: ps.SINDy, coeff_true: list[dict[str, float]] +) -> tuple[NDArray[np.float64], NDArray[np.float64], list[str]]: """Reformat true coefficients and coefficient matrix compatibly In order to calculate accuracy metrics between true and estimated @@ -1275,9 +1276,9 @@ def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> 
def _grid_locator_match( - exp_params: dict, + exp_params: dict[str, Any], exp_ind: tuple[int, ...], - param_spec: Collection[dict], + param_spec: Collection[dict[str, Any]], ind_spec: Collection[tuple[int, ...]], ) -> bool: """Determine whether experimental parameters match a specification From fa7f306c6fc4d1e75467b771e82e207d48b25188 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:59:22 +0000 Subject: [PATCH 02/46] feat(utils): Add helper grid locator function --- src/gen_experiments/utils.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index b631766..d82e17f 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -1318,6 +1318,36 @@ def _grid_locator_match( return found_match +def _strict_find_grid_match( + results: GridsearchResultDetails, + *, + params: Optional[dict[str, Any]] = None, + ind_spec: Optional[tuple[int | slice, int] | ellipsis] = None, +) -> TrialData: + if params is None: + params = {} + if ind_spec is None: + ind_spec = ... 
+ matches = [] + amax_arrays = [ + [single_ser_and_axis[1] for single_ser_and_axis in single_series_all_axes] + for _, single_series_all_axes in results["series_data"].items() + ] + full_inds = _amax_to_full_inds((ind_spec,), amax_arrays) + + for trajectory in results["plot_data"]: + if _grid_locator_match( + trajectory["params"], trajectory["pind"], (params,), full_inds + ): + matches.append(trajectory) + + if len(matches) > 1: + raise ValueError("Specification is nonunique; matched multiple results") + if len(matches) == 0: + raise ValueError("Could not find a match") + return matches[0]["data"] + + _EqTester = TypeVar("_EqTester") From d7f91f2c8034907f4fe15abc96422b68bf36bf43 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 14 Feb 2024 14:30:23 +0000 Subject: [PATCH 03/46] cln(utils): split utils into plotting and data --- src/gen_experiments/config.py | 11 +- src/gen_experiments/data.py | 235 +++++++++ src/gen_experiments/gridsearch.py | 2 +- src/gen_experiments/odes.py | 10 +- src/gen_experiments/pdes.py | 5 +- src/gen_experiments/plotting.py | 294 +++++++++++ src/gen_experiments/utils.py | 786 +----------------------------- 7 files changed, 544 insertions(+), 799 deletions(-) create mode 100644 src/gen_experiments/data.py create mode 100644 src/gen_experiments/plotting.py diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 0d017cb..9efb66f 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -3,14 +3,9 @@ import numpy as np import pysindy as ps -from gen_experiments.utils import ( - FullTrialData, - NestedDict, - SeriesDef, - SeriesList, - _PlotPrefs, - _signal_avg_power, -) +from gen_experiments.data import _signal_avg_power +from gen_experiments.plotting import _PlotPrefs +from gen_experiments.utils import FullTrialData, NestedDict, SeriesDef, SeriesList T = TypeVar("T") U = TypeVar("U") diff --git a/src/gen_experiments/data.py 
b/src/gen_experiments/data.py new file mode 100644 index 0000000..51fc6e4 --- /dev/null +++ b/src/gen_experiments/data.py @@ -0,0 +1,235 @@ +from math import ceil +from pathlib import Path +from typing import Callable +from warnings import warn + +import mitosis +import numpy as np +import scipy + +from gen_experiments.utils import GridsearchResultDetails + +INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12} +TRIALS_FOLDER = Path(__file__).parent.absolute() / "trials" + + +def gen_data( + rhs_func, + n_coord, + seed=None, + n_trajectories=1, + x0_center=None, + ic_stdev=3, + noise_abs=None, + noise_rel=None, + nonnegative=False, + dt=0.01, + t_end=10, +): + """Generate random training and test data + + Note that test data has no noise. + + Arguments: + rhs_func (Callable): the function to integrate + n_coord (int): number of coordinates needed for rhs_func + seed (int): the random seed for number generation + n_trajectories (int): number of trajectories of training data + x0_center (np.array): center of random initial conditions + ic_stdev (float): standard deviation for generating initial + conditions + noise_abs (float): measurement noise standard deviation. + Defaults to .1 if noise_rel is None. + noise_rel (float): measurement noise-to-signal power ratio. + Either noise_abs or noise_rel must be None. Defaults to + None. + nonnegative (bool): Whether x0 must be nonnegative, such as for + population models. If so, a gamma distribution is + used, rather than a normal distribution. 
+ + Returns: + dt, t_train, x_train, x_test, x_dot_test, x_train_true + """ + if noise_abs is not None and noise_rel is not None: + raise ValueError("Cannot specify both noise_abs and noise_rel") + elif noise_abs is None and noise_rel is None: + noise_abs = 0.1 + rng = np.random.default_rng(seed) + if x0_center is None: + x0_center = np.zeros((n_coord)) + t_train = np.arange(0, t_end, dt) + t_train_span = (t_train[0], t_train[-1]) + if nonnegative: + shape = ((x0_center + 1) / ic_stdev) ** 2 + scale = ic_stdev**2 / (x0_center + 1) + x0_train = np.array( + [rng.gamma(k, theta, n_trajectories) for k, theta in zip(shape, scale)] + ).T + x0_test = np.array([ + rng.gamma(k, theta, ceil(n_trajectories / 2)) + for k, theta in zip(shape, scale) + ]).T + else: + x0_train = ic_stdev * rng.standard_normal((n_trajectories, n_coord)) + x0_center + x0_test = ( + ic_stdev * rng.standard_normal((ceil(n_trajectories / 2), n_coord)) + + x0_center + ) + x_train = [] + for traj in range(n_trajectories): + x_train.append( + scipy.integrate.solve_ivp( + rhs_func, + t_train_span, + x0_train[traj, :], + t_eval=t_train, + **INTEGRATOR_KEYWORDS, + ).y.T + ) + + def _drop_and_warn(arrs): + maxlen = max(arr.shape[0] for arr in arrs) + + def _alert_short(arr): + if arr.shape[0] < maxlen: + warn(message="Dropping simulation due to blow-up") + return False + return True + + arrs = list(filter(_alert_short, arrs)) + if len(arrs) == 0: + raise ValueError( + "Simulations failed due to blow-up. 
System is too stiff for solver's" + " numerical tolerance" + ) + return arrs + + x_train = _drop_and_warn(x_train) + x_train = np.stack(x_train) + x_test = [] + for traj in range(ceil(n_trajectories / 2)): + x_test.append( + scipy.integrate.solve_ivp( + rhs_func, + t_train_span, + x0_test[traj, :], + t_eval=t_train, + **INTEGRATOR_KEYWORDS, + ).y.T + ) + x_test = _drop_and_warn(x_test) + x_test = np.array(x_test) + x_dot_test = np.array([[rhs_func(0, xij) for xij in xi] for xi in x_test]) + x_train_true = np.copy(x_train) + if noise_rel is not None: + noise_abs = np.sqrt(_signal_avg_power(x_test) * noise_rel) + x_train = x_train + noise_abs * rng.standard_normal(x_train.shape) + x_train = list(x_train) + x_test = list(x_test) + x_dot_test = list(x_dot_test) + return dt, t_train, x_train, x_test, x_dot_test, x_train_true + + +def gen_pde_data( + rhs_func: Callable, + init_cond: np.ndarray, + args: tuple, + dimension: int, + seed: int | None = None, + noise_abs: float | None = None, + noise_rel: float | None = None, + dt: float = 0.01, + t_end: int = 100, +): + """Generate PDE measurement data for training + + For simplicity, Trajectories have been removed, + Test data is the same as Train data. + + Arguments: + rhs_func: the function to integrate + init_cond: Initial Conditions for the PDE + args: Arguments for rhsfunc + dimension: Number of spatial dimensions (1, 2, or 3) + seed (int): the random seed for number generation + noise_abs (float): measurement noise standard deviation. + Defaults to .1 if noise_rel is None. + noise_rel (float): measurement noise relative to amplitude of + true data. Amplitude of data is calculated as the max value + of the power spectrum. Either noise_abs or noise_rel must + be None. Defaults to None. 
+ dt (float): time step for the PDE simulation + t_end (int): total time for the PDE simulation + + Returns: + dt, t_train, x_train, x_test, x_dot_test, x_train_true + """ + if noise_abs is not None and noise_rel is not None: + raise ValueError("Cannot specify both noise_abs and noise_rel") + elif noise_abs is None and noise_rel is None: + noise_abs = 0.1 + rng = np.random.default_rng(seed) + t_train = np.arange(0, t_end, dt) + t_train_span = (t_train[0], t_train[-1]) + x_train = [] + x_train.append( + scipy.integrate.solve_ivp( + rhs_func, + t_train_span, + init_cond, + t_eval=t_train, + args=args, + **INTEGRATOR_KEYWORDS, + ).y.T + ) + t, x = x_train[0].shape + x_train = np.stack(x_train, axis=-1) + if dimension == 1: + pass + elif dimension == 2: + x_train = np.reshape(x_train, (t, int(np.sqrt(x)), int(np.sqrt(x)), 1)) + elif dimension == 3: + x_train = np.reshape( + x_train, (t, int(np.cbrt(x)), int(np.cbrt(x)), int(np.cbrt(x)), 1) + ) + x_test = x_train + x_test = np.moveaxis(x_test, -1, 0) + x_dot_test = np.array( + [[rhs_func(0, xij, args[0], args[1]) for xij in xi] for xi in x_test] + ) + if dimension == 1: + x_dot_test = [np.moveaxis(x_dot_test, [0, 1], [-1, -2])] + pass + elif dimension == 2: + x_dot_test = np.reshape(x_dot_test, (t, int(np.sqrt(x)), int(np.sqrt(x)), 1)) + x_dot_test = [np.moveaxis(x_dot_test, 0, -2)] + elif dimension == 3: + x_dot_test = np.reshape( + x_dot_test, (t, int(np.cbrt(x)), int(np.cbrt(x)), int(np.cbrt(x)), 1) + ) + x_dot_test = [np.moveaxis(x_dot_test, 0, -2)] + x_train_true = np.copy(x_train) + if noise_rel is not None: + noise_abs = _max_amplitude(x_test) * noise_rel + x_train = x_train + noise_abs * rng.standard_normal(x_train.shape) + x_train = [np.moveaxis(x_train, 0, -2)] + x_train_true = np.moveaxis(x_train_true, 0, -2) + x_test = [np.moveaxis(x_test, [0, 1], [-1, -2])] + return dt, t_train, x_train, x_test, x_dot_test, x_train_true + + +def _max_amplitude(signal: np.ndarray): + return np.abs(scipy.fft.rfft(signal, 
axis=0)[1:]).max() / np.sqrt(len(signal)) + + +def _signal_avg_power(signal: np.ndarray) -> float: + return np.square(signal).mean() + + +def load_results(hexstr: str) -> GridsearchResultDetails: + """Load the results that mitosis saves + + Args: + hexstr: randomly-assigned identifier for the results to open + """ + return mitosis.load_trial_data(hexstr, trials_folder=TRIALS_FOLDER) diff --git a/src/gen_experiments/gridsearch.py b/src/gen_experiments/gridsearch.py index ee86082..676ad49 100644 --- a/src/gen_experiments/gridsearch.py +++ b/src/gen_experiments/gridsearch.py @@ -12,6 +12,7 @@ import gen_experiments from gen_experiments import config from gen_experiments.odes import plot_ode_panel +from gen_experiments.plotting import _PlotPrefs from gen_experiments.utils import ( GridsearchResult, GridsearchResultDetails, @@ -23,7 +24,6 @@ _amax_to_full_inds, _argopt, _grid_locator_match, - _PlotPrefs, simulate_test_data, ) diff --git a/src/gen_experiments/odes.py b/src/gen_experiments/odes.py index 25bc903..a054b2d 100644 --- a/src/gen_experiments/odes.py +++ b/src/gen_experiments/odes.py @@ -5,16 +5,18 @@ import pysindy as ps from . import config +from .data import gen_data +from .plotting import ( + compare_coefficient_plots, + plot_test_trajectories, + plot_training_data, +) from .utils import ( FullTrialData, TrialData, _make_model, coeff_metrics, - compare_coefficient_plots, - gen_data, integration_metrics, - plot_test_trajectories, - plot_training_data, simulate_test_data, unionize_coeff_matrices, ) diff --git a/src/gen_experiments/pdes.py b/src/gen_experiments/pdes.py index 4170bac..c7f905e 100644 --- a/src/gen_experiments/pdes.py +++ b/src/gen_experiments/pdes.py @@ -2,15 +2,14 @@ from pysindy.differentiation import SpectralDerivative from . 
import config +from .data import gen_pde_data +from .plotting import compare_coefficient_plots, plot_pde_training_data from .utils import ( FullTrialData, TrialData, _make_model, coeff_metrics, - compare_coefficient_plots, - gen_pde_data, integration_metrics, - plot_pde_training_data, simulate_test_data, unionize_coeff_matrices, ) diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py new file mode 100644 index 0000000..4a54230 --- /dev/null +++ b/src/gen_experiments/plotting.py @@ -0,0 +1,294 @@ +from dataclasses import dataclass, field +from types import EllipsisType as ellipsis +from typing import Annotated, Callable, Collection, Literal, Mapping, Sequence + +import matplotlib.pyplot as plt +import numpy as np +import scipy +import seaborn as sns + +PAL = sns.color_palette("Set1") +PLOT_KWS = {"alpha": 0.7, "linewidth": 3} + + +@dataclass(frozen=True) +class _PlotPrefs: + """Control which gridsearch data gets plotted, and a bit of how + + Args: + plot: whether to plot + rel_noise: Whether and how to convert true noise into relative noise + grid_params_match: dictionaries of parameters to match when plotted. OR + is applied across the collection + grid_ind_match: indexing tuple to match indices in a single series + gridsearch. Only positive integers are allowed, except the first + element may be slice(None). Alternatively, ellipsis to match all + indices + """ + + plot: bool = True + rel_noise: Literal[False] | Callable = False + grid_params_match: Collection[dict] = field(default_factory=lambda: ()) + grid_ind_match: Collection[tuple[int | slice, int]] | ellipsis = field( + default_factory=lambda: ... 
+ ) + + def __bool__(self): + return self.plot + + +def plot_coefficients( + coefficients: Annotated[np.ndarray, "(n_coord, n_features)"], + input_features: Sequence[str] = None, + feature_names: Sequence[str] = None, + ax: bool = None, + **heatmap_kws, +): + if input_features is None: + input_features = [r"$\dot x_" + f"{k}$" for k in range(coefficients.shape[0])] + else: + input_features = [r"$\dot " + f"{fi}$" for fi in input_features] + + if feature_names is None: + feature_names = [f"f{k}" for k in range(coefficients.shape[1])] + + with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}): + if ax is None: + fig, ax = plt.subplots(1, 1) + + heatmap_args = { + "xticklabels": input_features, + "yticklabels": feature_names, + "center": 0.0, + "cmap": sns.color_palette("vlag", n_colors=20, as_cmap=True), + "ax": ax, + "linewidths": 0.1, + "linecolor": "whitesmoke", + } + heatmap_args.update(**heatmap_kws) + + sns.heatmap(coefficients.T, **heatmap_args) + + ax.tick_params(axis="y", rotation=0) + + return ax + + +def compare_coefficient_plots( + coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"], + coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"], + input_features: Sequence[str] = None, + feature_names: Sequence[str] = None, +): + """Create plots of true and estimated coefficients.""" + n_cols = len(coefficients_est) + + # helps boost the color of small coefficients. Maybe log is better? 
+ def signed_sqrt(x): + return np.sign(x) * np.sqrt(np.abs(x)) + + with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}): + fig, axs = plt.subplots( + 1, 2, figsize=(1.9 * n_cols, 8), sharey=True, sharex=True + ) + + max_clean = max(np.max(np.abs(c)) for c in coefficients_est) + max_noisy = max(np.max(np.abs(c)) for c in coefficients_true) + max_mag = np.sqrt(max(max_clean, max_noisy)) + + plot_coefficients( + signed_sqrt(coefficients_true), + input_features=input_features, + feature_names=feature_names, + ax=axs[0], + cbar=False, + vmax=max_mag, + vmin=-max_mag, + ) + + plot_coefficients( + signed_sqrt(coefficients_est), + input_features=input_features, + feature_names=feature_names, + ax=axs[1], + cbar=False, + ) + + axs[0].set_title("True Coefficients", rotation=45) + axs[1].set_title("Est. Coefficients", rotation=45) + + fig.tight_layout() + + +def plot_training_trajectory( + ax: plt.Axes, + x_train: np.ndarray, + x_true: np.ndarray, + x_smooth: np.ndarray, + labels: bool = True, +) -> None: + """Plot a single training trajectory""" + if x_train.shape[1] == 2: + ax.plot(x_true[:, 0], x_true[:, 1], ".", label="True", color=PAL[0], **PLOT_KWS) + ax.plot( + x_train[:, 0], + x_train[:, 1], + ".", + label="Measured", + color=PAL[1], + **PLOT_KWS, + ) + if np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12: + ax.plot( + x_smooth[:, 0], + x_smooth[:, 1], + ".", + label="Smoothed", + color=PAL[2], + **PLOT_KWS, + ) + if labels: + ax.set(xlabel="$x_0$", ylabel="$x_1$") + elif x_train.shape[1] == 3: + ax.plot( + x_true[:, 0], + x_true[:, 1], + x_true[:, 2], + color=PAL[0], + label="True values", + **PLOT_KWS, + ) + + ax.plot( + x_train[:, 0], + x_train[:, 1], + x_train[:, 2], + ".", + color=PAL[1], + label="Measured values", + alpha=0.3, + ) + if np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12: + ax.plot( + x_smooth[:, 0], + x_smooth[:, 1], + x_smooth[:, 2], + ".", + color=PAL[2], + label="Smoothed values", + alpha=0.3, + ) + if 
labels: + ax.set(xlabel="$x$", ylabel="$y$", zlabel="$z$") + else: + raise ValueError("Can only plot 2d or 3d data.") + + +def plot_training_data(x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.ndarray): + """Plot training data (and smoothed training data, if different).""" + fig = plt.figure(figsize=(12, 6)) + if x_train.shape[-1] == 2: + ax0 = fig.add_subplot(1, 2, 1) + elif x_train.shape[-1] == 3: + ax0 = fig.add_subplot(1, 2, 1, projection="3d") + plot_training_trajectory(ax0, x_train, x_true, x_smooth) + ax0.legend() + ax0.set(title="Training data") + ax1 = fig.add_subplot(1, 2, 2) + ax1.loglog(np.abs(scipy.fft.rfft(x_train, axis=0)) / np.sqrt(len(x_train))) + ax1.set(title="Training Data Absolute Spectral Density") + ax1.set(xlabel="Wavenumber") + ax1.set(ylabel="Magnitude") + return fig + + +def plot_pde_training_data(last_train, last_train_true, smoothed_last_train): + """Plot training data (and smoothed training data, if different).""" + # 1D: + if len(last_train.shape) == 3: + fig, axs = plt.subplots(1, 3, figsize=(18, 6)) + axs[0].imshow(last_train_true, vmin=0, vmax=last_train_true.max()) + axs[0].set(title="True Data") + axs[1].imshow(last_train_true - last_train, vmin=0, vmax=last_train_true.max()) + axs[1].set(title="Noise") + axs[2].imshow( + last_train_true - smoothed_last_train, vmin=0, vmax=last_train_true.max() + ) + axs[2].set(title="Smoothed Data") + return plt.show() + + +def plot_test_sim_data_1d_panel( + axs: Sequence[plt.Axes], + x_test: np.ndarray, + x_sim: np.ndarray, + t_test: np.ndarray, + t_sim: np.ndarray, +) -> None: + for ordinate, ax in enumerate(axs): + ax.plot(t_test, x_test[:, ordinate], "k", label="true trajectory") + axs[ordinate].plot(t_sim, x_sim[:, ordinate], "r--", label="model simulation") + axs[ordinate].legend() + axs[ordinate].set(xlabel="t", ylabel="$x_{}$".format(ordinate)) + + +def _plot_test_sim_data_2d( + axs: Annotated[Sequence[plt.Axes], "len=2"], + x_test: np.ndarray, + x_sim: np.ndarray, + labels: bool 
= True, +) -> None: + axs[0].plot(x_test[:, 0], x_test[:, 1], "k", label="True Trajectory") + if labels: + axs[0].set(xlabel="$x_0$", ylabel="$x_1$") + axs[1].plot(x_sim[:, 0], x_sim[:, 1], "r--", label="Simulation") + if labels: + axs[1].set(xlabel="$x_0$", ylabel="$x_1$") + + +def _plot_test_sim_data_3d( + axs: Annotated[Sequence[plt.Axes], "len=3"], + x_test: np.ndarray, + x_sim: np.ndarray, + labels: bool = True, +) -> None: + axs[0].plot(x_test[:, 0], x_test[:, 1], x_test[:, 2], "k", label="True Trajectory") + if labels: + axs[0].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$") + axs[1].plot(x_sim[:, 0], x_sim[:, 1], x_sim[:, 2], "r--", label="Simulation") + if labels: + axs[1].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$") + + +def plot_test_trajectories( + x_test: np.ndarray, x_sim: np.ndarray, t_test: np.ndarray, t_sim: np.ndarray +) -> Mapping[str, np.ndarray]: + """Plot a test trajectory + + Args: + last_test: a single trajectory of the system + model: a trained model to simulate and compare to test data + dt: the time interval in test data + + Returns: + A dict with two keys, "t_sim" (the simulation times) and + "x_sim" (the simulated trajectory) + """ + fig, axs = plt.subplots(x_test.shape[1], 1, sharex=True, figsize=(7, 9)) + plt.suptitle("Test Trajectories by Dimension") + plot_test_sim_data_1d_panel(axs, x_test, x_sim, t_test, t_sim) + axs[-1].legend() + + plt.suptitle("Full Test Trajectories") + if x_test.shape[1] == 2: + fig, axs = plt.subplots(1, 2, figsize=(10, 4.5)) + _plot_test_sim_data_2d(axs, x_test, x_sim) + elif x_test.shape[1] == 3: + fig, axs = plt.subplots( + 1, 2, figsize=(10, 4.5), subplot_kw={"projection": "3d"} + ) + _plot_test_sim_data_3d(axs, x_test, x_sim) + else: + raise ValueError("Can only plot 2d or 3d data.") + axs[0].set(title="true trajectory") + axs[1].set(title="model simulation") diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index d82e17f..0bc4da2 100644 --- 
a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -1,42 +1,17 @@ from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from itertools import chain -from math import ceil -from pathlib import Path from types import EllipsisType as ellipsis -from typing import ( - Annotated, - Any, - Callable, - Collection, - Literal, - Mapping, - Optional, - Sequence, - TypedDict, - TypeVar, -) +from typing import Annotated, Any, Collection, Optional, Sequence, TypedDict, TypeVar from warnings import warn import auto_ks as aks import kalman -import matplotlib.pyplot as plt -import mitosis import numpy as np import pysindy as ps -import scipy -import seaborn as sns import sklearn -from matplotlib.axes._axes import Axes -from matplotlib.figure import Figure -from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec, SubplotSpec from numpy.typing import DTypeLike, NDArray -INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12} -PAL = sns.color_palette("Set1") -PLOT_KWS = {"alpha": 0.7, "linewidth": 3} -TRIALS_FOLDER = Path(__file__).parent.absolute() / "trials" - class TrialData(TypedDict): dt: float @@ -89,245 +64,6 @@ class GridsearchResultDetails(TypedDict): main: float -@dataclass(frozen=True) -class _PlotPrefs: - """Control which gridsearch data gets plotted, and a bit of how - - Args: - plot: whether to plot - rel_noise: Whether and how to convert true noise into relative noise - grid_params_match: dictionaries of parameters to match when plotted. OR - is applied across the collection - grid_ind_match: indexing tuple to match indices in a single series - gridsearch. Only positive integers are allowed, except the first - element may be slice(None). 
Alternatively, ellipsis to match all - indices - """ - - plot: bool = True - rel_noise: Literal[False] | Callable = False - grid_params_match: Collection[dict] = field(default_factory=lambda: ()) - grid_ind_match: Collection[tuple[int | slice, int]] | ellipsis = field( - default_factory=lambda: ... - ) - - def __bool__(self): - return self.plot - - -def gen_data( - rhs_func, - n_coord, - seed=None, - n_trajectories=1, - x0_center=None, - ic_stdev=3, - noise_abs=None, - noise_rel=None, - nonnegative=False, - dt=0.01, - t_end=10, -): - """Generate random training and test data - - Note that test data has no noise. - - Arguments: - rhs_func (Callable): the function to integrate - n_coord (int): number of coordinates needed for rhs_func - seed (int): the random seed for number generation - n_trajectories (int): number of trajectories of training data - x0_center (np.array): center of random initial conditions - ic_stdev (float): standard deviation for generating initial - conditions - noise_abs (float): measurement noise standard deviation. - Defaults to .1 if noise_rel is None. - noise_rel (float): measurement noise-to-signal power ratio. - Either noise_abs or noise_rel must be None. Defaults to - None. - nonnegative (bool): Whether x0 must be nonnegative, such as for - population models. If so, a gamma distribution is - used, rather than a normal distribution. 
- - Returns: - dt, t_train, x_train, x_test, x_dot_test, x_train_true - """ - if noise_abs is not None and noise_rel is not None: - raise ValueError("Cannot specify both noise_abs and noise_rel") - elif noise_abs is None and noise_rel is None: - noise_abs = 0.1 - rng = np.random.default_rng(seed) - if x0_center is None: - x0_center = np.zeros((n_coord)) - t_train = np.arange(0, t_end, dt) - t_train_span = (t_train[0], t_train[-1]) - if nonnegative: - shape = ((x0_center + 1) / ic_stdev) ** 2 - scale = ic_stdev**2 / (x0_center + 1) - x0_train = np.array( - [rng.gamma(k, theta, n_trajectories) for k, theta in zip(shape, scale)] - ).T - x0_test = np.array([ - rng.gamma(k, theta, ceil(n_trajectories / 2)) - for k, theta in zip(shape, scale) - ]).T - else: - x0_train = ic_stdev * rng.standard_normal((n_trajectories, n_coord)) + x0_center - x0_test = ( - ic_stdev * rng.standard_normal((ceil(n_trajectories / 2), n_coord)) - + x0_center - ) - x_train = [] - for traj in range(n_trajectories): - x_train.append( - scipy.integrate.solve_ivp( - rhs_func, - t_train_span, - x0_train[traj, :], - t_eval=t_train, - **INTEGRATOR_KEYWORDS, - ).y.T - ) - - def _drop_and_warn(arrs): - maxlen = max(arr.shape[0] for arr in arrs) - - def _alert_short(arr): - if arr.shape[0] < maxlen: - warn(message="Dropping simulation due to blow-up") - return False - return True - - arrs = list(filter(_alert_short, arrs)) - if len(arrs) == 0: - raise ValueError( - "Simulations failed due to blow-up. 
System is too stiff for solver's" - " numerical tolerance" - ) - return arrs - - x_train = _drop_and_warn(x_train) - x_train = np.stack(x_train) - x_test = [] - for traj in range(ceil(n_trajectories / 2)): - x_test.append( - scipy.integrate.solve_ivp( - rhs_func, - t_train_span, - x0_test[traj, :], - t_eval=t_train, - **INTEGRATOR_KEYWORDS, - ).y.T - ) - x_test = _drop_and_warn(x_test) - x_test = np.array(x_test) - x_dot_test = np.array([[rhs_func(0, xij) for xij in xi] for xi in x_test]) - x_train_true = np.copy(x_train) - if noise_rel is not None: - noise_abs = np.sqrt(_signal_avg_power(x_test) * noise_rel) - x_train = x_train + noise_abs * rng.standard_normal(x_train.shape) - x_train = list(x_train) - x_test = list(x_test) - x_dot_test = list(x_dot_test) - return dt, t_train, x_train, x_test, x_dot_test, x_train_true - - -def gen_pde_data( - rhs_func: Callable, - init_cond: np.ndarray, - args: tuple, - dimension: int, - seed: int | None = None, - noise_abs: float | None = None, - noise_rel: float | None = None, - dt: float = 0.01, - t_end: int = 100, -): - """Generate PDE measurement data for training - - For simplicity, Trajectories have been removed, - Test data is the same as Train data. - - Arguments: - rhs_func: the function to integrate - init_cond: Initial Conditions for the PDE - args: Arguments for rhsfunc - dimension: Number of spatial dimensions (1, 2, or 3) - seed (int): the random seed for number generation - noise_abs (float): measurement noise standard deviation. - Defaults to .1 if noise_rel is None. - noise_rel (float): measurement noise relative to amplitude of - true data. Amplitude of data is calculated as the max value - of the power spectrum. Either noise_abs or noise_rel must - be None. Defaults to None. 
- dt (float): time step for the PDE simulation - t_end (int): total time for the PDE simulation - - Returns: - dt, t_train, x_train, x_test, x_dot_test, x_train_true - """ - if noise_abs is not None and noise_rel is not None: - raise ValueError("Cannot specify both noise_abs and noise_rel") - elif noise_abs is None and noise_rel is None: - noise_abs = 0.1 - rng = np.random.default_rng(seed) - t_train = np.arange(0, t_end, dt) - t_train_span = (t_train[0], t_train[-1]) - x_train = [] - x_train.append( - scipy.integrate.solve_ivp( - rhs_func, - t_train_span, - init_cond, - t_eval=t_train, - args=args, - **INTEGRATOR_KEYWORDS, - ).y.T - ) - t, x = x_train[0].shape - x_train = np.stack(x_train, axis=-1) - if dimension == 1: - pass - elif dimension == 2: - x_train = np.reshape(x_train, (t, int(np.sqrt(x)), int(np.sqrt(x)), 1)) - elif dimension == 3: - x_train = np.reshape( - x_train, (t, int(np.cbrt(x)), int(np.cbrt(x)), int(np.cbrt(x)), 1) - ) - x_test = x_train - x_test = np.moveaxis(x_test, -1, 0) - x_dot_test = np.array( - [[rhs_func(0, xij, args[0], args[1]) for xij in xi] for xi in x_test] - ) - if dimension == 1: - x_dot_test = [np.moveaxis(x_dot_test, [0, 1], [-1, -2])] - pass - elif dimension == 2: - x_dot_test = np.reshape(x_dot_test, (t, int(np.sqrt(x)), int(np.sqrt(x)), 1)) - x_dot_test = [np.moveaxis(x_dot_test, 0, -2)] - elif dimension == 3: - x_dot_test = np.reshape( - x_dot_test, (t, int(np.cbrt(x)), int(np.cbrt(x)), int(np.cbrt(x)), 1) - ) - x_dot_test = [np.moveaxis(x_dot_test, 0, -2)] - x_train_true = np.copy(x_train) - if noise_rel is not None: - noise_abs = _max_amplitude(x_test) * noise_rel - x_train = x_train + noise_abs * rng.standard_normal(x_train.shape) - x_train = [np.moveaxis(x_train, 0, -2)] - x_train_true = np.moveaxis(x_train_true, 0, -2) - x_test = [np.moveaxis(x_test, [0, 1], [-1, -2])] - return dt, t_train, x_train, x_test, x_dot_test, x_train_true - - -def _max_amplitude(signal: np.ndarray): - return np.abs(scipy.fft.rfft(signal, 
axis=0)[1:]).max() / np.sqrt(len(signal)) - - -def _signal_avg_power(signal: np.ndarray) -> float: - return np.square(signal).mean() - - def diff_lookup(kind): normalized_kind = kind.lower().replace(" ", "") if normalized_kind == "finitedifference": @@ -370,89 +106,6 @@ def opt_lookup(kind): raise ValueError -def plot_coefficients( - coefficients: Annotated[np.ndarray, "(n_coord, n_features)"], - input_features: Sequence[str] = None, - feature_names: Sequence[str] = None, - ax: bool = None, - **heatmap_kws, -): - if input_features is None: - input_features = [r"$\dot x_" + f"{k}$" for k in range(coefficients.shape[0])] - else: - input_features = [r"$\dot " + f"{fi}$" for fi in input_features] - - if feature_names is None: - feature_names = [f"f{k}" for k in range(coefficients.shape[1])] - - with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}): - if ax is None: - fig, ax = plt.subplots(1, 1) - - heatmap_args = { - "xticklabels": input_features, - "yticklabels": feature_names, - "center": 0.0, - "cmap": sns.color_palette("vlag", n_colors=20, as_cmap=True), - "ax": ax, - "linewidths": 0.1, - "linecolor": "whitesmoke", - } - heatmap_args.update(**heatmap_kws) - - sns.heatmap(coefficients.T, **heatmap_args) - - ax.tick_params(axis="y", rotation=0) - - return ax - - -def compare_coefficient_plots( - coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"], - coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"], - input_features: Sequence[str] = None, - feature_names: Sequence[str] = None, -): - """Create plots of true and estimated coefficients.""" - n_cols = len(coefficients_est) - - # helps boost the color of small coefficients. Maybe log is better? 
- def signed_sqrt(x): - return np.sign(x) * np.sqrt(np.abs(x)) - - with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}): - fig, axs = plt.subplots( - 1, 2, figsize=(1.9 * n_cols, 8), sharey=True, sharex=True - ) - - max_clean = max(np.max(np.abs(c)) for c in coefficients_est) - max_noisy = max(np.max(np.abs(c)) for c in coefficients_true) - max_mag = np.sqrt(max(max_clean, max_noisy)) - - plot_coefficients( - signed_sqrt(coefficients_true), - input_features=input_features, - feature_names=feature_names, - ax=axs[0], - cbar=False, - vmax=max_mag, - vmin=-max_mag, - ) - - plot_coefficients( - signed_sqrt(coefficients_est), - input_features=input_features, - feature_names=feature_names, - ax=axs[1], - cbar=False, - ) - - axs[0].set_title("True Coefficients", rotation=45) - axs[1].set_title("Est. Coefficients", rotation=45) - - fig.tight_layout() - - def coeff_metrics(coefficients, coeff_true): metrics = {} metrics["coeff_precision"] = sklearn.metrics.precision_score( @@ -568,146 +221,6 @@ def finalize_param(lookup_func, pdict, lookup_key): ) -def plot_training_trajectory( - ax: plt.Axes, - x_train: np.ndarray, - x_true: np.ndarray, - x_smooth: np.ndarray, - labels: bool = True, -) -> None: - """Plot a single training trajectory""" - if x_train.shape[1] == 2: - ax.plot(x_true[:, 0], x_true[:, 1], ".", label="True", color=PAL[0], **PLOT_KWS) - ax.plot( - x_train[:, 0], - x_train[:, 1], - ".", - label="Measured", - color=PAL[1], - **PLOT_KWS, - ) - if np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12: - ax.plot( - x_smooth[:, 0], - x_smooth[:, 1], - ".", - label="Smoothed", - color=PAL[2], - **PLOT_KWS, - ) - if labels: - ax.set(xlabel="$x_0$", ylabel="$x_1$") - elif x_train.shape[1] == 3: - ax.plot( - x_true[:, 0], - x_true[:, 1], - x_true[:, 2], - color=PAL[0], - label="True values", - **PLOT_KWS, - ) - - ax.plot( - x_train[:, 0], - x_train[:, 1], - x_train[:, 2], - ".", - color=PAL[1], - label="Measured values", - alpha=0.3, - ) - if 
np.linalg.norm(x_smooth - x_train) / x_smooth.size > 1e-12: - ax.plot( - x_smooth[:, 0], - x_smooth[:, 1], - x_smooth[:, 2], - ".", - color=PAL[2], - label="Smoothed values", - alpha=0.3, - ) - if labels: - ax.set(xlabel="$x$", ylabel="$y$", zlabel="$z$") - else: - raise ValueError("Can only plot 2d or 3d data.") - - -def plot_training_data(x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.ndarray): - """Plot training data (and smoothed training data, if different).""" - fig = plt.figure(figsize=(12, 6)) - if x_train.shape[-1] == 2: - ax0 = fig.add_subplot(1, 2, 1) - elif x_train.shape[-1] == 3: - ax0 = fig.add_subplot(1, 2, 1, projection="3d") - plot_training_trajectory(ax0, x_train, x_true, x_smooth) - ax0.legend() - ax0.set(title="Training data") - ax1 = fig.add_subplot(1, 2, 2) - ax1.loglog(np.abs(scipy.fft.rfft(x_train, axis=0)) / np.sqrt(len(x_train))) - ax1.set(title="Training Data Absolute Spectral Density") - ax1.set(xlabel="Wavenumber") - ax1.set(ylabel="Magnitude") - return fig - - -def plot_pde_training_data(last_train, last_train_true, smoothed_last_train): - """Plot training data (and smoothed training data, if different).""" - # 1D: - if len(last_train.shape) == 3: - fig, axs = plt.subplots(1, 3, figsize=(18, 6)) - axs[0].imshow(last_train_true, vmin=0, vmax=last_train_true.max()) - axs[0].set(title="True Data") - axs[1].imshow(last_train_true - last_train, vmin=0, vmax=last_train_true.max()) - axs[1].set(title="Noise") - axs[2].imshow( - last_train_true - smoothed_last_train, vmin=0, vmax=last_train_true.max() - ) - axs[2].set(title="Smoothed Data") - return plt.show() - - -def plot_test_sim_data_1d_panel( - axs: Sequence[plt.Axes], - x_test: np.ndarray, - x_sim: np.ndarray, - t_test: np.ndarray, - t_sim: np.ndarray, -) -> None: - for ordinate, ax in enumerate(axs): - ax.plot(t_test, x_test[:, ordinate], "k", label="true trajectory") - axs[ordinate].plot(t_sim, x_sim[:, ordinate], "r--", label="model simulation") - axs[ordinate].legend() - 
axs[ordinate].set(xlabel="t", ylabel="$x_{}$".format(ordinate)) - - -def _plot_test_sim_data_2d( - axs: Annotated[Sequence[plt.Axes], "len=2"], - x_test: np.ndarray, - x_sim: np.ndarray, - labels: bool = True, -) -> None: - axs[0].plot(x_test[:, 0], x_test[:, 1], "k", label="True Trajectory") - if labels: - axs[0].set(xlabel="$x_0$", ylabel="$x_1$") - axs[1].plot(x_sim[:, 0], x_sim[:, 1], "r--", label="Simulation") - if labels: - axs[1].set(xlabel="$x_0$", ylabel="$x_1$") - - -def _plot_test_sim_data_3d( - axs: Annotated[Sequence[plt.Axes], "len=3"], - x_test: np.ndarray, - x_sim: np.ndarray, - labels: bool = True, -) -> None: - axs[0].plot(x_test[:, 0], x_test[:, 1], x_test[:, 2], "k", label="True Trajectory") - if labels: - axs[0].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$") - axs[1].plot(x_sim[:, 0], x_sim[:, 1], x_sim[:, 2], "r--", label="Simulation") - if labels: - axs[1].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$") - - def simulate_test_data(model: ps.SINDy, dt: float, x_test: np.ndarray) -> TrialData: """Add simulation data to grid_data @@ -727,40 +240,6 @@ def simulate_test_data(model: ps.SINDy, dt: float, x_test: np.ndarray) -> TrialD return {"t_sim": t_sim, "x_sim": x_sim, "t_test": t_test} -def plot_test_trajectories( - x_test: np.ndarray, x_sim: np.ndarray, t_test: np.ndarray, t_sim: np.ndarray -) -> Mapping[str, np.ndarray]: - """Plot a test trajectory - - Args: - last_test: a single trajectory of the system - model: a trained model to simulate and compare to test data - dt: the time interval in test data - - Returns: - A dict with two keys, "t_sim" (the simulation times) and - "x_sim" (the simulated trajectory) - """ - fig, axs = plt.subplots(x_test.shape[1], 1, sharex=True, figsize=(7, 9)) - plt.suptitle("Test Trajectories by Dimension") - plot_test_sim_data_1d_panel(axs, x_test, x_sim, t_test, t_sim) - axs[-1].legend() - - plt.suptitle("Full Test Trajectories") - if x_test.shape[1] == 2: - fig, axs = plt.subplots(1, 2, figsize=(10, 
4.5)) - _plot_test_sim_data_2d(axs, x_test, x_sim) - elif x_test.shape[1] == 3: - fig, axs = plt.subplots( - 1, 2, figsize=(10, 4.5), subplot_kw={"projection": "3d"} - ) - _plot_test_sim_data_3d(axs, x_test, x_sim) - else: - raise ValueError("Can only plot 2d or 3d data.") - axs[0].set(title="true trajectory") - axs[1].set(title="model simulation") - - @dataclass class SeriesDef: """The details of constructing the ragged axes of a grid search. @@ -934,15 +413,6 @@ def proj(curr_params, t): return est_alpha -def load_results(hexstr: str) -> GridsearchResultDetails: - """Load the results that mitosis saves - - Args: - hexstr: randomly-assigned identifier for the results to open - """ - return mitosis.load_trial_data(hexstr, trials_folder=TRIALS_FOLDER) - - def _amax_to_full_inds( amax_inds: Collection[tuple[int | slice, int] | ellipsis], amax_arrays: list[list[GridsearchResult]], @@ -974,245 +444,6 @@ def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: return all_inds -def _setup_summary_fig( - n_sub: int, *, fig_cell: Optional[tuple[Figure, SubplotSpec]] = None -) -> tuple[Figure, GridSpec | GridSpecFromSubplotSpec]: - """Create neatly laid-out arrangements for subplots - - Creates an evenly-spaced gridpsec to fit follow-on plots and a - figure, if required. 
- - Args: - n_sub: number of grid elements to create - nest_parent: parent grid cell within which to to build a nested - gridspec - Returns: - a figure and gridspec if nest_parent is not provided, otherwise, - None and a sub-gridspec - """ - n_rows = max(n_sub // 3, (n_sub + 2) // 3) - n_cols = min(n_sub, 3) - figsize = [3 * n_cols, 3 * n_rows] - if fig_cell is None: - fig = plt.figure(figsize=figsize) - gs = fig.add_gridspec(n_rows, n_cols) - return fig, gs - fig, cell = fig_cell - return fig, cell.subgridspec(n_rows, n_cols) - - -def plot_experiment_across_gridpoints( - hexstr: str, - *args: tuple[str, dict] | ellipsis | tuple[int | slice, int], - style: str, - fig_cell: tuple[Figure, SubplotSpec] = None, - annotations: bool = True, -) -> tuple[Figure, Sequence[str]]: - """Plot a single experiment's test across multiple gridpoints - - Arguments: - hexstr: hexadecimal suffix for the experiment's result file. - args: From which gridpoints to load data, described either as: - - a local name and the parameters defining the gridpoint to match. 
- - ellipsis, indicating optima across all metrics across all plot - axes - - an indexing tuple indicating optima for that tuple's location in - the gridsearch argmax array - Matching logic is AND(OR(parameter matches), OR(index matches)) - style: either "test" or "train" - Returns: - the plotted figure - """ - - fig, gs = _setup_summary_fig(len(args), fig_cell=fig_cell) - if fig_cell is not None: - fig.suptitle("How do different smoothing compare on an ODE?") - p_names = [] - results = load_results(hexstr) - amax_arrays = [ - [single_ser_and_axis[1] for single_ser_and_axis in single_series_all_axes] - for _, single_series_all_axes in results["series_data"].items() - ] - parg_inds = { - argind - for argind, arg in enumerate(args) - if isinstance(arg, tuple) and isinstance(arg[0], str) - } - indarg_inds = set(range(len(args))) - parg_inds - pargs = [args[i] for i in parg_inds] - indargs = [args[i] for i in indarg_inds] - if not indargs: - indargs = {...} - full_inds = _amax_to_full_inds(indargs, amax_arrays) - - for cell, (p_name, params) in zip(gs, pargs): - for trajectory in results["plot_data"]: - if _grid_locator_match( - trajectory["params"], trajectory["pind"], [params], full_inds - ): - p_names.append(p_name) - ax = _plot_train_test_cell( - (fig, cell), trajectory, style, annotations=False - ) - if annotations: - ax.set_title(p_name) - break - else: - warn(f"Did not find a parameter match for {p_name} experiment") - if annotations: - ax.legend() - return Figure, p_names - - -def _plot_train_test_cell( - fig_cell: tuple[Figure, SubplotSpec | int | tuple[int, int, int]], - trajectory: SavedData, - style: str, - annotations: bool = False, -) -> Axes: - """Plot either the training or test data in a single cell""" - fig, cell = fig_cell - if trajectory["data"]["x_test"].shape[1] == 2: - ax = fig.add_subplot(cell) - plot_func = _plot_test_sim_data_2d - else: - ax = fig.add_subplot(cell, projection="3d") - plot_func = _plot_test_sim_data_3d - if style.lower() == 
"training": - plot_func = plot_training_trajectory - plot_location = ax - data = ( - trajectory["data"]["x_train"], - trajectory["data"]["x_true"], - trajectory["data"]["smooth_train"], - ) - elif style.lower() == "test": - plot_location = [ax, ax] - data = ( - trajectory["data"]["x_test"], - trajectory["data"]["x_sim"], - ) - plot_func(plot_location, *data, labels=annotations) - return ax - - -def plot_point_across_experiments( - params: dict, - point: ellipsis | tuple[int | slice, int] = ..., - *args: tuple[str, str], - style: str, -) -> Figure: - """Plot a single parameter's training or test across multiple experiments - - Arguments: - params: parameters defining the gridpoint to match - point: gridpoint spec from the argmax array, defined as either an - - ellipsis, indicating optima across all metrics across all plot - axes - - indexing tuple indicating optima for that tuple's location in - the gridsearch argmax array - args (experiment_name, hexstr): From which experiments to load - data, described as a local name and the hexadecimal suffix - of the result file. 
- style: either "test" or "train" - Returns: - the plotted figure - """ - fig, gs = _setup_summary_fig(len(args)) - fig.suptitle("How well does a smoothing method perform across ODEs?") - - for cell, (ode_name, hexstr) in zip(gs, args): - results = load_results(hexstr) - amax_arrays = [ - [single_ser_and_axis[1] for single_ser_and_axis in single_series_all_axes] - for _, single_series_all_axes in results["series_data"].items() - ] - full_inds = _amax_to_full_inds((point,), amax_arrays) - for trajectory in results["plot_data"]: - if _grid_locator_match( - trajectory["params"], trajectory["pind"], [params], full_inds - ): - ax = _plot_train_test_cell( - [fig, cell], trajectory, style, annotations=False - ) - ax.set_title(ode_name) - break - else: - warn(f"Did not find a parameter match for {ode_name} experiment") - ax.legend() - return fig - - -def plot_summary_metric( - metric: str, grid_axis_name: tuple[str, Collection], *args: tuple[str, str] -) -> None: - """After multiple gridsearches, plot a comparison for all ODEs - - Plots the overall results for a single metric, single grid axis - Args: - metric: which metric is being plotted - grid_axis: the name of the parameter varied and the values of - the parameter. - *args: each additional tuple contains the name of an ODE and - the hexstr under which it's data is saved. - """ - fig, gs = _setup_summary_fig(len(args)) - fig.suptitle( - f"How well do the methods work on different ODEs as {grid_axis_name} changes?" 
- ) - for cell, (ode_name, hexstr) in zip(gs, args): - results = load_results(hexstr) - grid_axis_index = results["grid_params"].index(grid_axis_name) - grid_axis = results["grid_vals"][grid_axis_index] - metric_index = results["metrics"].index(metric) - ax = fig.add_subplot(cell) - for s_name, s_data in results["series_data"].items(): - ax.plot(grid_axis, s_data[grid_axis_index][0][metric_index], label=s_name) - ax.set_title(ode_name) - ax.legend() - - -def plot_summary_test_train( - exps: Sequence[tuple[str, str]], - params: Sequence[tuple[str, dict] | ellipsis | tuple[int | slice, int]], - style: str, -) -> None: - """Plot a comparison of different variants across experiments - - Args: - exps: From which experiments to load data, described as a local name - and the hexadecimal suffix of the result file. - params: which gridpoints to compare, described as either: - - a tuple of local name and parameters to match. - - ellipsis, indicating optima across all metrics across all plot - axes - - an indexing tuple indicating optima for that tuple's location in - the gridsearch argmax array - Matching logic is AND(OR(parameter matches), OR(index matches)) - style - """ - n_exp = len(exps) - n_params = len(params) - figsize = (3 * n_params, 3 * n_exp) - fig = plt.figure(figsize=figsize) - grid = fig.add_gridspec(n_exp, 2, width_ratios=(1, 20)) - for n_row, (ode_name, hexstr) in enumerate(exps): - cell = grid[n_row, 1] - _, p_names = plot_experiment_across_gridpoints( - hexstr, *params, style=style, fig_cell=(fig, cell), annotations=False - ) - empty_ax = fig.add_subplot(grid[n_row, 0]) - empty_ax.axis("off") - empty_ax.text( - -0.1, 0.5, ode_name, va="center", transform=empty_ax.transAxes, rotation=90 - ) - first_row = fig.get_axes()[:n_params] - for ax, p_name in zip(first_row, p_names): - ax.set_title(p_name) - fig.subplots_adjust(top=0.95) - return fig - - def _argopt( arr: np.ndarray, axis: int | tuple[int, ...] 
= None, opt: str = "max" ) -> np.ndarray[tuple[int, ...]]: @@ -1318,7 +549,7 @@ def _grid_locator_match( return found_match -def _strict_find_grid_match( +def strict_find_grid_match( results: GridsearchResultDetails, *, params: Optional[dict[str, Any]] = None, @@ -1356,14 +587,3 @@ def _param_normalize(val: _EqTester) -> _EqTester | str: return repr(val) else: return val - - -# class GridMatch: -# bad_keys: dict -# truth: bool -# def __init__(self, truth: bool): -# self.truth = truth -# self.bad_keys = {} - -# def __bool__(self) -> bool: -# return not self.bad_keys From 27d32fe33610250e1c37e9cea21412581481dfd1 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 14 Feb 2024 17:27:35 +0000 Subject: [PATCH 04/46] bug: allow coefficient plots with matplotlib rc useTex --- src/gen_experiments/odes.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/gen_experiments/odes.py b/src/gen_experiments/odes.py index a054b2d..8d0cc29 100644 --- a/src/gen_experiments/odes.py +++ b/src/gen_experiments/odes.py @@ -222,8 +222,8 @@ def plot_ode_panel(trial_data: FullTrialData): compare_coefficient_plots( trial_data["coeff_fit"], trial_data["coeff_true"], - input_features=trial_data["input_features"], - feature_names=trial_data["feature_names"], + input_features=[_texify(feat) for feat in trial_data["input_features"]], + feature_names=[_texify(feat) for feat in trial_data["feature_names"]], ) plot_test_trajectories( trial_data["x_test"], @@ -232,3 +232,11 @@ def plot_ode_panel(trial_data: FullTrialData): trial_data["t_sim"], ) plt.show() + + +def _texify(input: str) -> str: + if input[0] != "$": + input = "$" + input + if input[-1] != "$": + input = input + "$" + return input From 483fdc8da8492e979495dfd9e3440106f0700f05 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:07:25 +0000 Subject: [PATCH 05/46] bug(plots): Prevent tex errors 
in matplotlib usetex --- src/gen_experiments/plotting.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 4a54230..1c71180 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -44,10 +44,17 @@ def plot_coefficients( ax: bool = None, **heatmap_kws, ): + def detex(input: str) -> str: + if input[0] == "$": + input = input[1:] + if input[-1] == "$": + input = input[:-1] + return input + if input_features is None: input_features = [r"$\dot x_" + f"{k}$" for k in range(coefficients.shape[0])] else: - input_features = [r"$\dot " + f"{fi}$" for fi in input_features] + input_features = [r"$\dot " + f"{detex(fi)}$" for fi in input_features] if feature_names is None: feature_names = [f"f{k}" for k in range(coefficients.shape[1])] From 1f142df145910e0ccedf76b4981be0de5f5c657c Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:08:13 +0000 Subject: [PATCH 06/46] bld: upgrade mitosis dependencies --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a64b84a..324197d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ # Since the point of the package is reproducibility, incl. 
all dev # dependencies dependencies = [ - "mitosis >=0.3.0rc1, <0.4.0", + "mitosis >=0.4.0rc2", "derivative @ git+https://github.com/Jacob-Stevens-Haas/derivative@hyperparams", "pysindy[cvxpy,miosr] @ git+https://github.com/dynamicslab/pysindy@master", "kalman @ git+https://github.com/Jacob-Stevens-Haas/kalman@0.1.0", @@ -74,7 +74,7 @@ extend-exclude = ''' | env )/ ''' -preview = 1 +preview = true [tool.codespell] skip = '*.html,./env,./scratch/*,todo' From 2a94de6a6ea774a6006d2ab71463ecfb811d6512 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 17 Feb 2024 18:54:41 +0000 Subject: [PATCH 07/46] type(utils): Specify that TrialData is SINDyTrialData, and type out amax func --- pyproject.toml | 2 +- src/gen_experiments/__init__.py | 6 +-- src/gen_experiments/config.py | 4 +- src/gen_experiments/data.py | 2 +- src/gen_experiments/gridsearch.py | 4 +- src/gen_experiments/odes.py | 12 ++--- src/gen_experiments/pdes.py | 10 ++-- src/gen_experiments/utils.py | 76 ++++++++++++++++++++++--------- 8 files changed, 74 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 324197d..d2613b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "pysindy-experiments" dynamic = ["version"] description = "My general exam experiments" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" license = {file = "LICENSE"} keywords = ["Machine Learning", "Science", "Mathematics", "Experiments"] authors = [ diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index 7f38c6b..fbffc64 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -8,7 +8,7 @@ from pysindy import BaseDifferentiation, FiniteDifference, SINDy # type: ignore from . 
import gridsearch, odes, pdes -from .utils import TrialData +from .utils import SINDyTrialData this_module = importlib.import_module(__name__) BORING_ARRAY = np.ones((2, 2)) @@ -38,13 +38,13 @@ class NoExperiment: @staticmethod def run( *args: Any, return_all: bool = True, **kwargs: Any - ) -> Scores | tuple[Scores, TrialData]: + ) -> Scores | tuple[Scores, SINDyTrialData]: metrics = defaultdict( lambda: 1, main=1, ) if return_all: - trial_data: TrialData = { + trial_data: SINDyTrialData = { "dt": 1, "coeff_true": BORING_ARRAY[:1], "coeff_fit": BORING_ARRAY[:1], diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 9efb66f..2e9137c 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -5,7 +5,7 @@ from gen_experiments.data import _signal_avg_power from gen_experiments.plotting import _PlotPrefs -from gen_experiments.utils import FullTrialData, NestedDict, SeriesDef, SeriesList +from gen_experiments.utils import FullSINDyTrialData, NestedDict, SeriesDef, SeriesList T = TypeVar("T") U = TypeVar("U") @@ -16,7 +16,7 @@ def ND(d: dict[T, U]) -> NestedDict[T, U]: def _convert_abs_rel_noise( - grid_vals: list, grid_params: list, recent_results: FullTrialData + grid_vals: list, grid_params: list, recent_results: FullSINDyTrialData ): """Convert abs_noise grid_vals to rel_noise""" signal = np.stack(recent_results["x_true"], axis=-1) diff --git a/src/gen_experiments/data.py b/src/gen_experiments/data.py index 51fc6e4..a3091ad 100644 --- a/src/gen_experiments/data.py +++ b/src/gen_experiments/data.py @@ -218,7 +218,7 @@ def gen_pde_data( return dt, t_train, x_train, x_test, x_dot_test, x_train_true -def _max_amplitude(signal: np.ndarray): +def _max_amplitude(signal: np.ndarray) -> float: return np.abs(scipy.fft.rfft(signal, axis=0)[1:]).max() / np.sqrt(len(signal)) diff --git a/src/gen_experiments/gridsearch.py b/src/gen_experiments/gridsearch.py index 676ad49..e7aa488 100644 --- a/src/gen_experiments/gridsearch.py +++ 
b/src/gen_experiments/gridsearch.py @@ -20,7 +20,7 @@ SavedData, SeriesDef, SeriesList, - TrialData, + SINDyTrialData, _amax_to_full_inds, _argopt, _grid_locator_match, @@ -115,7 +115,7 @@ def run( curr_results, grid_data = base_ex.run( new_seed, **curr_other_params, display=False, return_all=True ) - grid_data: TrialData + grid_data: SINDyTrialData intermediate_data.append( {"params": curr_other_params.flatten(), "pind": ind, "data": grid_data} ) diff --git a/src/gen_experiments/odes.py b/src/gen_experiments/odes.py index 8d0cc29..2f3acf3 100644 --- a/src/gen_experiments/odes.py +++ b/src/gen_experiments/odes.py @@ -12,8 +12,8 @@ plot_training_data, ) from .utils import ( - FullTrialData, - TrialData, + FullSINDyTrialData, + SINDyTrialData, _make_model, coeff_metrics, integration_metrics, @@ -161,7 +161,7 @@ def run( opt_params: dict, display: bool = True, return_all: bool = False, -) -> dict | tuple[dict, TrialData | FullTrialData]: +) -> dict | tuple[dict, SINDyTrialData | FullSINDyTrialData]: rhsfunc = ode_setup[group]["rhsfunc"] input_features = ode_setup[group]["input_features"] coeff_true = ode_setup[group]["coeff_true"] @@ -187,7 +187,7 @@ def run( coeff_true, coefficients, feature_names = unionize_coeff_matrices(model, coeff_true) sim_ind = -1 - trial_data: TrialData = { + trial_data: SINDyTrialData = { "dt": dt, "coeff_true": coeff_true, "coeff_fit": coefficients, @@ -202,7 +202,7 @@ def run( "model": model, } if display: - trial_data: FullTrialData = trial_data | simulate_test_data( + trial_data: FullSINDyTrialData = trial_data | simulate_test_data( trial_data["model"], trial_data["dt"], trial_data["x_test"] ) plot_ode_panel(trial_data) @@ -214,7 +214,7 @@ def run( return metrics -def plot_ode_panel(trial_data: FullTrialData): +def plot_ode_panel(trial_data: FullSINDyTrialData): trial_data["model"].print() plot_training_data( trial_data["x_train"], trial_data["x_true"], trial_data["smooth_train"] diff --git a/src/gen_experiments/pdes.py 
b/src/gen_experiments/pdes.py index c7f905e..8cb712b 100644 --- a/src/gen_experiments/pdes.py +++ b/src/gen_experiments/pdes.py @@ -5,8 +5,8 @@ from .data import gen_pde_data from .plotting import compare_coefficient_plots, plot_pde_training_data from .utils import ( - FullTrialData, - TrialData, + FullSINDyTrialData, + SINDyTrialData, _make_model, coeff_metrics, integration_metrics, @@ -149,7 +149,7 @@ def run( opt_params: dict, display: bool = True, return_all: bool = False, -) -> dict | tuple[dict, TrialData | FullTrialData]: +) -> dict | tuple[dict, SINDyTrialData | FullSINDyTrialData]: rhsfunc = pde_setup[group]["rhsfunc"]["func"] input_features = pde_setup[group]["input_features"] initial_condition = sim_params["init_cond"] @@ -177,7 +177,7 @@ def run( coeff_true, coefficients, feature_names = unionize_coeff_matrices(model, coeff_true) sim_ind = -1 - trial_data: TrialData = { + trial_data: SINDyTrialData = { "dt": dt, "coeff_true": coeff_true, "coeff_fit": coefficients, @@ -192,7 +192,7 @@ def run( "model": model, } if display: - trial_data: FullTrialData = trial_data | simulate_test_data( + trial_data: FullSINDyTrialData = trial_data | simulate_test_data( trial_data["model"], trial_data["dt"], trial_data["x_test"] ) trial_data["model"].print() diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 0bc4da2..4d53aff 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -2,7 +2,16 @@ from dataclasses import dataclass from itertools import chain from types import EllipsisType as ellipsis -from typing import Annotated, Any, Collection, Optional, Sequence, TypedDict, TypeVar +from typing import ( + Annotated, + Any, + Collection, + Optional, + Sequence, + TypedDict, + TypeVar, + cast, +) from warnings import warn import auto_ks as aks @@ -10,16 +19,23 @@ import numpy as np import pysindy as ps import sklearn -from numpy.typing import DTypeLike, NDArray +import sklearn.metrics +from numpy.typing import DTypeLike, 
NBitBase, NDArray +NpFlt = np.dtype[np.floating[NBitBase]] +Float1D = np.ndarray[tuple[int], NpFlt] +Float2D = np.ndarray[tuple[int, int], NpFlt] +Shape = TypeVar("Shape", bound=tuple[int, ...]) +FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]] -class TrialData(TypedDict): + +class SINDyTrialData(TypedDict): dt: float - coeff_true: Annotated[np.ndarray, "(n_coord, n_features)"] - coeff_fit: Annotated[np.ndarray, "(n_coord, n_features)"] + coeff_true: Annotated[Float2D, "(n_coord, n_features)"] + coeff_fit: Annotated[Float2D, "(n_coord, n_features)"] feature_names: Annotated[list[str], "length=n_features"] input_features: Annotated[list[str], "length=n_coord"] - t_train: np.ndarray + t_train: Float1D x_train: np.ndarray x_true: np.ndarray smooth_train: np.ndarray @@ -28,20 +44,25 @@ class TrialData(TypedDict): model: ps.SINDy -class FullTrialData(TrialData): - t_sim: np.ndarray +class SINDyTrialUpdate(TypedDict): + t_sim: Float1D + t_test: Float1D + x_sim: FloatND + + +class FullSINDyTrialData(SINDyTrialData): + t_sim: Float1D x_sim: np.ndarray class SavedData(TypedDict): params: dict pind: tuple[int] - data: TrialData | FullTrialData + data: SINDyTrialData | FullSINDyTrialData T = TypeVar("T", bound=np.generic) GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] # type: ignore - SeriesData = Annotated[ list[ tuple[ @@ -215,28 +236,28 @@ def finalize_param(lookup_func, pdict, lookup_key): return ps.SINDy( differentiation_method=diff, optimizer=opt, - t_default=dt, + t_default=dt, # type: ignore feature_library=features, feature_names=input_features, ) -def simulate_test_data(model: ps.SINDy, dt: float, x_test: np.ndarray) -> TrialData: +def simulate_test_data(model: ps.SINDy, dt: float, x_test: Float2D) -> SINDyTrialUpdate: """Add simulation data to grid_data This includes the t_sim and x_sim keys. Does not mutate argument. 
Returns: Complete GridPointData """ - t_test = np.arange(len(x_test) * dt, step=dt) + t_test = cast(Float1D, np.arange(0, len(x_test) * dt, step=dt)) t_sim = t_test try: - x_sim = model.simulate(x_test[0], t_test) + x_sim = cast(Float2D, model.simulate(x_test[0], t_test)) except ValueError: warn(message="Simulation blew up; returning zeros") x_sim = np.zeros_like(x_test) # truncate if integration returns wrong number of points - t_sim = t_test[: len(x_sim)] + t_sim = cast(Float1D, t_test[: len(x_sim)]) return {"t_sim": t_sim, "x_sim": x_sim, "t_test": t_test} @@ -350,7 +371,7 @@ def __setitem__(self, key, value): else: return super().__setitem__(key, value) - def update(self, other: dict): + def update(self, other: dict): # type: ignore try: for k, v in other.items(): self.__setitem__(k, v) @@ -414,9 +435,20 @@ def proj(curr_params, t): def _amax_to_full_inds( - amax_inds: Collection[tuple[int | slice, int] | ellipsis], - amax_arrays: list[list[GridsearchResult]], + amax_inds: Collection[tuple[int | slice, int] | ellipsis] | ellipsis, + amax_arrays: list[list[GridsearchResult[np.void]]], ) -> set[tuple[int, ...]]: + """Find full indexers to selected elements of argmax arrays + + Args: + amax_inds: selection statemtent of which argmaxes to return. + amax_arrays: arrays of indexes to full gridsearch that are responsible for + the computed max values. First level of nesting reflects series(?), second + level reflects which grid grid axis. 
+ Returns: + all indexers to full gridsearch that are requested by amax_inds + """ + def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: return tuple(int(el) for el in tuple_like) @@ -438,15 +470,15 @@ def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: for el in arr.flatten() } elif isinstance(ind[0], int): - all_inds |= {np_to_primitive(plot_axis_results[ind])} + all_inds |= {np_to_primitive(cast(np.void, plot_axis_results[ind]))} else: # ind[0] is slice(None) all_inds |= {np_to_primitive(el) for el in plot_axis_results[ind]} return all_inds def _argopt( - arr: np.ndarray, axis: int | tuple[int, ...] = None, opt: str = "max" -) -> np.ndarray[tuple[int, ...]]: + arr: FloatND, axis: Optional[int | tuple[int, ...]] = None, opt: str = "max" +) -> NDArray[tuple[int, ...]]: """Calculate the argmax/min, but accept tuple axis. Ignores NaN values @@ -554,7 +586,7 @@ def strict_find_grid_match( *, params: Optional[dict[str, Any]] = None, ind_spec: Optional[tuple[int | slice, int] | ellipsis] = None, -) -> TrialData: +) -> SINDyTrialData: if params is None: params = {} if ind_spec is None: From 64b7f0572990ff65edfd220892f9db094c40ab1f Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 17 Feb 2024 19:42:09 +0000 Subject: [PATCH 08/46] typing: Add mypy settings and finish types in utils --- pyproject.toml | 16 ++++++++++++++++ src/gen_experiments/utils.py | 7 +++++-- tests/test_all.py | 2 ++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d2613b7..95a2dc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,3 +93,19 @@ markers = ["slow"] [tool.mypy] files = ["src/gen_experiments/__init__.py"] + +[[tool.mypy.overrides]] +module="auto_ks.*" +ignore_missing_imports=true + +[[tool.mypy.overrides]] +module="sklearn.*" +ignore_missing_imports=true + +[[tool.mypy.overrides]] +module="pysindy.*" +ignore_missing_imports=true + +[[tool.mypy.overrides]] 
+module="kalman.*" +ignore_missing_imports=true diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 4d53aff..7841f4d 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -1,4 +1,5 @@ from collections import defaultdict +from collections.abc import Iterable from dataclasses import dataclass from itertools import chain from types import EllipsisType as ellipsis @@ -450,7 +451,7 @@ def _amax_to_full_inds( """ def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: - return tuple(int(el) for el in tuple_like) + return tuple(int(el) for el in cast(Iterable, tuple_like)) if amax_inds is ...: # grab each element from arrays in list of lists of arrays return { @@ -478,7 +479,7 @@ def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: def _argopt( arr: FloatND, axis: Optional[int | tuple[int, ...]] = None, opt: str = "max" -) -> NDArray[tuple[int, ...]]: +) -> NDArray[np.void]: """Calculate the argmax/min, but accept tuple axis. Ignores NaN values @@ -494,6 +495,8 @@ def _argopt( tuples of length m """ dtype: DTypeLike = [(f"f{axind}", "i") for axind in range(arr.ndim)] + if axis is None: + axis = () axis = (axis,) if isinstance(axis, int) else axis keep_axes = tuple(sorted(set(range(arr.ndim)) - set(axis))) keep_shape = tuple(arr.shape[ax] for ax in keep_axes) diff --git a/tests/test_all.py b/tests/test_all.py index dd5c789..962d0df 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -96,6 +96,8 @@ def test_argopt_empty_tuple_axis(): result = utils._argopt(arr, ()) expected = np.array([(0,), (1,), (2,), (3,)], dtype=[("f0", "i")]) np.testing.assert_array_equal(result, expected) + result = utils._argopt(arr, None) + pass def test_argopt_int_axis(): From e8a75c2185710318b59fb564cfd8982dfa5dc3de Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 18 Feb 2024 15:18:26 +0000 Subject: [PATCH 09/46] ok even better typing --- pyproject.toml | 4 ++++ 
src/gen_experiments/data.py | 35 ++++++++++++++++++----------------- src/gen_experiments/debug.py | 25 +++++++++++++++++++++++++ src/gen_experiments/utils.py | 4 ++-- 4 files changed, 49 insertions(+), 19 deletions(-) create mode 100644 src/gen_experiments/debug.py diff --git a/pyproject.toml b/pyproject.toml index 95a2dc6..1b9ce17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,3 +109,7 @@ ignore_missing_imports=true [[tool.mypy.overrides]] module="kalman.*" ignore_missing_imports=true + +[[tool.mypy.overrides]] +module="scipy.*" +ignore_missing_imports=true diff --git a/src/gen_experiments/data.py b/src/gen_experiments/data.py index a3091ad..0d651f2 100644 --- a/src/gen_experiments/data.py +++ b/src/gen_experiments/data.py @@ -1,31 +1,31 @@ from math import ceil from pathlib import Path -from typing import Callable +from typing import Callable, Optional, cast from warnings import warn import mitosis import numpy as np import scipy -from gen_experiments.utils import GridsearchResultDetails +from gen_experiments.utils import Float1D, Float2D, GridsearchResultDetails INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12} TRIALS_FOLDER = Path(__file__).parent.absolute() / "trials" def gen_data( - rhs_func, - n_coord, - seed=None, - n_trajectories=1, - x0_center=None, - ic_stdev=3, - noise_abs=None, - noise_rel=None, - nonnegative=False, - dt=0.01, - t_end=10, -): + rhs_func: Callable, + n_coord: int, + seed: Optional[int] = None, + n_trajectories: int = 1, + x0_center: Optional[Float1D] = None, + ic_stdev: float = 3, + noise_abs: Optional[float] = None, + noise_rel: Optional[float] = None, + nonnegative: bool = False, + dt: float = 0.01, + t_end: float = 10, +) -> tuple[float, Float1D, Float2D, Float2D, Float2D, Float2D]: """Generate random training and test data Note that test data has no noise. 
@@ -57,7 +57,7 @@ def gen_data( rng = np.random.default_rng(seed) if x0_center is None: x0_center = np.zeros((n_coord)) - t_train = np.arange(0, t_end, dt) + t_train = np.arange(0, t_end, dt, dtype=np.float_) t_train_span = (t_train[0], t_train[-1]) if nonnegative: shape = ((x0_center + 1) / ic_stdev) ** 2 @@ -123,10 +123,11 @@ def _alert_short(arr): x_train_true = np.copy(x_train) if noise_rel is not None: noise_abs = np.sqrt(_signal_avg_power(x_test) * noise_rel) - x_train = x_train + noise_abs * rng.standard_normal(x_train.shape) + x_train = x_train + cast(float, noise_abs) * rng.standard_normal(x_train.shape) x_train = list(x_train) x_test = list(x_test) x_dot_test = list(x_dot_test) + x_train_true = list(x_train_true) return dt, t_train, x_train, x_test, x_dot_test, x_train_true @@ -211,7 +212,7 @@ def gen_pde_data( x_train_true = np.copy(x_train) if noise_rel is not None: noise_abs = _max_amplitude(x_test) * noise_rel - x_train = x_train + noise_abs * rng.standard_normal(x_train.shape) + x_train = x_train + cast(float, noise_abs) * rng.standard_normal(x_train.shape) x_train = [np.moveaxis(x_train, 0, -2)] x_train_true = np.moveaxis(x_train_true, 0, -2) x_test = [np.moveaxis(x_test, [0, 1], [-1, -2])] diff --git a/src/gen_experiments/debug.py b/src/gen_experiments/debug.py new file mode 100644 index 0000000..a1193ca --- /dev/null +++ b/src/gen_experiments/debug.py @@ -0,0 +1,25 @@ +from typing import Annotated, Generic, TypedDict, TypeVar + +import numpy as np +from numpy.typing import DTypeLike, NBitBase, NDArray + +# T = TypeVar("T") + +# class Foo[T]: +# items: list[T] + +# def __init__(self, thing: T): +# self.items = [thing, thing] + +# Bar = + + +T = TypeVar("T", bound=np.generic) +Foo = NDArray[T] +Bar = Annotated[NDArray, "foobar"] + +lil_foo = NDArray[np.void] + + +def baz(qux: Foo[np.void]): + pass diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 7841f4d..65580ee 100644 --- a/src/gen_experiments/utils.py +++ 
b/src/gen_experiments/utils.py @@ -63,12 +63,12 @@ class SavedData(TypedDict): T = TypeVar("T", bound=np.generic) -GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] # type: ignore +GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] SeriesData = Annotated[ list[ tuple[ Annotated[GridsearchResult, "metrics"], - Annotated[GridsearchResult, "arg_opts"], + Annotated[GridsearchResult[np.void], "arg_opts"], ] ], "len=n_grid_axes", From fa382fb77d3f126dc01a0fbf8546c97da3774319 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 18 Feb 2024 18:04:44 +0000 Subject: [PATCH 10/46] feat(plots): Also remove ticks when labels=False in train/test plots --- src/gen_experiments/plotting.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 1c71180..0aa09bd 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -6,6 +6,7 @@ import numpy as np import scipy import seaborn as sns +from matplotlib.axes import Axes PAL = sns.color_palette("Set1") PLOT_KWS = {"alpha": 0.7, "linewidth": 3} @@ -128,7 +129,7 @@ def signed_sqrt(x): def plot_training_trajectory( - ax: plt.Axes, + ax: Axes, x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.ndarray, @@ -156,6 +157,8 @@ def plot_training_trajectory( ) if labels: ax.set(xlabel="$x_0$", ylabel="$x_1$") + else: + ax.set(xticks=[], yticks=[]) elif x_train.shape[1] == 3: ax.plot( x_true[:, 0], @@ -187,6 +190,8 @@ def plot_training_trajectory( ) if labels: ax.set(xlabel="$x$", ylabel="$y$", zlabel="$z$") + else: + ax.set(xticks=[], yticks=[], zticks=[]) else: raise ValueError("Can only plot 2d or 3d data.") @@ -226,7 +231,7 @@ def plot_pde_training_data(last_train, last_train_true, smoothed_last_train): def plot_test_sim_data_1d_panel( - axs: Sequence[plt.Axes], + axs: Sequence[Axes], x_test: np.ndarray, x_sim: 
np.ndarray, t_test: np.ndarray, @@ -240,31 +245,33 @@ def plot_test_sim_data_1d_panel( def _plot_test_sim_data_2d( - axs: Annotated[Sequence[plt.Axes], "len=2"], + axs: Annotated[Sequence[Axes], "len=2"], x_test: np.ndarray, x_sim: np.ndarray, labels: bool = True, ) -> None: axs[0].plot(x_test[:, 0], x_test[:, 1], "k", label="True Trajectory") - if labels: - axs[0].set(xlabel="$x_0$", ylabel="$x_1$") axs[1].plot(x_sim[:, 0], x_sim[:, 1], "r--", label="Simulation") - if labels: - axs[1].set(xlabel="$x_0$", ylabel="$x_1$") + for ax in axs: + if labels: + ax.set(xlabel="$x_0$", ylabel="$x_1$") + else: + ax.set(xticks=[], yticks=[]) def _plot_test_sim_data_3d( - axs: Annotated[Sequence[plt.Axes], "len=3"], + axs: Annotated[Sequence[Axes], "len=3"], x_test: np.ndarray, x_sim: np.ndarray, labels: bool = True, ) -> None: axs[0].plot(x_test[:, 0], x_test[:, 1], x_test[:, 2], "k", label="True Trajectory") - if labels: - axs[0].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$") axs[1].plot(x_sim[:, 0], x_sim[:, 1], x_sim[:, 2], "r--", label="Simulation") - if labels: - axs[1].set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$") + for ax in axs: + if labels: + ax.set(xlabel="$x_0$", ylabel="$x_1$", zlabel="$x_2$") + else: + ax.set(xticks=[], yticks=[], zticks=[]) def plot_test_trajectories( From 69a65b6986ed877bec5be5d842a0593456e6af54 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 18 Feb 2024 19:53:56 +0000 Subject: [PATCH 11/46] bld: Revert python version 3.9 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1b9ce17..94d1b14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "pysindy-experiments" dynamic = ["version"] description = "My general exam experiments" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.9" license = {file = "LICENSE"} keywords = ["Machine Learning", "Science", "Mathematics", 
"Experiments"] authors = [ From 5369b2db5b508a8add72e4c0bf732bdd31e999d1 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 18 Feb 2024 20:04:41 +0000 Subject: [PATCH 12/46] types(mock): Set dtype arg to get right types --- pyproject.toml | 2 +- src/gen_experiments/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 94d1b14..81edcf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ addopts = '-m "not slow"' markers = ["slow"] [tool.mypy] -files = ["src/gen_experiments/__init__.py"] +files = ["src/gen_experiments/__init__.py", "src/gen_experiments/utils.py"] [[tool.mypy.overrides]] module="auto_ks.*" diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index fbffc64..f25bde1 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -11,7 +11,7 @@ from .utils import SINDyTrialData this_module = importlib.import_module(__name__) -BORING_ARRAY = np.ones((2, 2)) +BORING_ARRAY = np.ones((2, 2), dtype=float) Scores = Mapping[str, float] @@ -51,7 +51,7 @@ def run( # "coefficients": boring_array, "feature_names": ["1"], "input_features": ["x", "y"], - "t_train": np.arange(0, 1, 1), + "t_train": np.arange(0, 1, 1, dtype=float), "x_train": BORING_ARRAY, "x_true": BORING_ARRAY, "smooth_train": BORING_ARRAY, From fa78b9262e4c0760de9e856d90b2d57d2ae9f732 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:29:00 +0000 Subject: [PATCH 13/46] feat(API): export make_model() --- src/gen_experiments/__init__.py | 2 +- src/gen_experiments/utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index f25bde1..978a8e2 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -8,7 +8,7 @@ from pysindy import 
BaseDifferentiation, FiniteDifference, SINDy # type: ignore from . import gridsearch, odes, pdes -from .utils import SINDyTrialData +from .utils import SINDyTrialData, make_model # noqa: F401 this_module = importlib.import_module(__name__) BORING_ARRAY = np.ones((2, 2), dtype=float) diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 65580ee..49eddae 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -122,6 +122,8 @@ def opt_lookup(kind): return ps.SR3 elif normalized_kind == "miosr": return ps.MIOSR + elif normalized_kind == "trap": + return ps.TrappingSR3 elif normalized_kind == "ensemble": return ps.EnsembleOptimizer else: @@ -207,7 +209,7 @@ def unionize_coeff_matrices( return true_coeff_mat, new_est_coeff, model_features -def _make_model( +def make_model( input_features: list[str], dt: float, diff_params: dict, From bfdc1a50576fd736b3c8c9d1d2e59303929b87c8 Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Tue, 20 Feb 2024 19:42:02 -0800 Subject: [PATCH 14/46] ENH: Introduced better LV parameters --- src/gen_experiments/odes.py | 9 +++++---- src/gen_experiments/pdes.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gen_experiments/odes.py b/src/gen_experiments/odes.py index 2f3acf3..e45277a 100644 --- a/src/gen_experiments/odes.py +++ b/src/gen_experiments/odes.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Callable import matplotlib.pyplot as plt @@ -14,9 +15,9 @@ from .utils import ( FullSINDyTrialData, SINDyTrialData, - _make_model, coeff_metrics, integration_metrics, + make_model, simulate_test_data, unionize_coeff_matrices, ) @@ -59,7 +60,7 @@ def forcing(t, x): p_duff = [0.2, 0.05, 1] -p_lotka = [1, 10] +p_lotka = [5, 1] p_ross = [0.2, 0.2, 5.7] p_hopf = [-0.05, 1, 1] @@ -73,7 +74,7 @@ def forcing(t, x): ], }, "lv": { - "rhsfunc": ps.utils.odes.lotka, + "rhsfunc": partial(ps.utils.odes.lotka, p=p_lotka), "input_features": ["x", "y"], "coeff_true": [ 
{"x": p_lotka[0], "x y": -p_lotka[1]}, @@ -181,7 +182,7 @@ def run( nonnegative=nonnegative, **sim_params, ) - model = _make_model(input_features, dt, diff_params, feat_params, opt_params) + model = make_model(input_features, dt, diff_params, feat_params, opt_params) model.fit(x_train) coeff_true, coefficients, feature_names = unionize_coeff_matrices(model, coeff_true) diff --git a/src/gen_experiments/pdes.py b/src/gen_experiments/pdes.py index 8cb712b..985ed10 100644 --- a/src/gen_experiments/pdes.py +++ b/src/gen_experiments/pdes.py @@ -7,9 +7,9 @@ from .utils import ( FullSINDyTrialData, SINDyTrialData, - _make_model, coeff_metrics, integration_metrics, + make_model, simulate_test_data, unionize_coeff_matrices, ) @@ -171,7 +171,7 @@ def run( dt=time_args[0], t_end=time_args[1], ) - model = _make_model(input_features, dt, diff_params, feat_params, opt_params) + model = make_model(input_features, dt, diff_params, feat_params, opt_params) model.fit(x_train, t=t_train) coeff_true, coefficients, feature_names = unionize_coeff_matrices(model, coeff_true) From b1b875bd654b5a2cfda1435f5bd444a04fe8cc24 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:30:24 +0000 Subject: [PATCH 15/46] feat(config): Add experiment for gridsearch vs auto kalman --- src/gen_experiments/config.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 2e9137c..4e7ed28 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -421,6 +421,16 @@ def addn(x): diff_series["sg2"], ], ), + "multikalman": SeriesList( + "diff_params", + "Differentiation Method", + [ + diff_series["auto-kalman3"], + diff_series["kalman2"], + diff_series["tv2"], + diff_series["sg2"], + ], + ), } From e4e07112f1c9dce48f5e331365d47224fa324d4f Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 27 Feb 2024 
17:48:08 +0000 Subject: [PATCH 16/46] cln(config): Set name of AutoKalman seriesdefs --- src/gen_experiments/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 4e7ed28..47547ba 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -333,19 +333,19 @@ def addn(x): [np.logspace(-4, 0, 5)], ), "auto-kalman": SeriesDef( - "Kalman", + "Auto Kalman", diff_params["kalman"], ["diff_params.alpha", "diff_params.meas_var"], [(None,), (0.1, 0.5, 1, 2, 4, 8)], ), "auto-kalman2": SeriesDef( - "Kalman", + "Auto Kalman", diff_params["kalman"], ["diff_params.alpha", "diff_params.meas_var"], [(None,), (0.01, 0.25, 1, 4, 16, 64)], ), "auto-kalman3": SeriesDef( - "Kalman", + "Auto Kalman", diff_params["kalman"], ["diff_params.alpha"], [(None,)], From b900d594be27ab2e7db0145b89c810d46c334854 Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Mon, 4 Mar 2024 18:09:36 -0800 Subject: [PATCH 17/46] ENH: changed autokalman diff_params to kalman-auto --- src/gen_experiments/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 47547ba..ac12e0a 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -346,7 +346,7 @@ def addn(x): ), "auto-kalman3": SeriesDef( "Auto Kalman", - diff_params["kalman"], + diff_params["kalman-auto"], ["diff_params.alpha"], [(None,)], ), From ba34313e8f270a636f6fc7536dbc037f3a0e66d2 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:40:19 +0000 Subject: [PATCH 18/46] feat(gridsearch): Define GridLocator object --- src/gen_experiments/gridsearch.py | 41 +++++++++++++++++++++++++++++-- src/gen_experiments/utils.py | 2 +- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/gen_experiments/gridsearch.py b/src/gen_experiments/gridsearch.py index 
e7aa488..81f548f 100644 --- a/src/gen_experiments/gridsearch.py +++ b/src/gen_experiments/gridsearch.py @@ -1,8 +1,19 @@ from copy import copy +from dataclasses import dataclass, field from functools import partial from logging import getLogger from pprint import pformat -from typing import Annotated, Callable, Iterable, Optional, Sequence, TypeVar +from types import EllipsisType as ellipsis +from typing import ( + Annotated, + Any, + Callable, + Collection, + Iterable, + Optional, + Sequence, + TypeVar, +) import matplotlib.pyplot as plt import numpy as np @@ -36,6 +47,32 @@ SkinnySpecs = Optional[tuple[tuple[str, ...], tuple[OtherSliceDef, ...]]] +@dataclass(frozen=True) +class GridLocator: + """A specification of which points in a gridsearch to match. + + Rather than specifying the exact point in the mega-grid of every + varied axis, specify by result, e.g "all of the points from the + Kalman series that had the best mean squared error as noise was + varied. + + Args: + metric: The metric in which to find results. An ellipsis means "any metrics" + keep_axis: The grid-varied parameter in which to find results, or a tuple of + that axis and position along that axis. To search a particular value of + that parameter, use the param_match kwarg. An ellipsis means "any axis" + param_match: A collection of dictionaries to match parameter values represented + by points in the gridsearch. Dictionary equality is checked for every + non-callable value; for callable values, it is applied to the grid + parameters and must return a boolean. Logical OR is applied across the + collection + """ + + metric: str | ellipsis = field(default=...) + keep_axis: str | tuple[str, int] | ellipsis = field(default=...) 
+ param_match: Collection[dict[str, Any]] = field(default=()) + + def run( seed: int, group: str, @@ -44,7 +81,7 @@ def run( grid_decisions: Sequence[str], other_params: dict, series_params: Optional[SeriesList] = None, - metrics: Optional[Sequence[str]] = None, + metrics: Sequence[str] = (), plot_prefs: _PlotPrefs = _PlotPrefs(True, False, ()), skinny_specs: SkinnySpecs = None, ) -> GridsearchResultDetails: diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 49eddae..76d3408 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -446,7 +446,7 @@ def _amax_to_full_inds( Args: amax_inds: selection statemtent of which argmaxes to return. amax_arrays: arrays of indexes to full gridsearch that are responsible for - the computed max values. First level of nesting reflects series(?), second + the computed max values. First level of nesting reflects series, second level reflects which grid grid axis. Returns: all indexers to full gridsearch that are requested by amax_inds From e34149584eef945aa53226e57f8d7842d51863ac Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:49:22 +0000 Subject: [PATCH 19/46] cln: Extract gridsearch functionality (mostly) to subpackage --- src/gen_experiments/__init__.py | 5 +- src/gen_experiments/config.py | 3 +- src/gen_experiments/data.py | 5 +- src/gen_experiments/debug.py | 25 -- .../{gridsearch.py => gridsearch/__init__.py} | 254 ++++++++++-- src/gen_experiments/gridsearch/typing.py | 201 ++++++++++ src/gen_experiments/plotting.py | 31 +- src/gen_experiments/typing.py | 10 + src/gen_experiments/utils.py | 377 +----------------- tests/test_all.py | 37 +- 10 files changed, 467 insertions(+), 481 deletions(-) delete mode 100644 src/gen_experiments/debug.py rename src/gen_experiments/{gridsearch.py => gridsearch/__init__.py} (64%) create mode 100644 src/gen_experiments/gridsearch/typing.py create mode 100644 
src/gen_experiments/typing.py diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index 978a8e2..61f64fa 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -5,9 +5,10 @@ import numpy as np from numpy.typing import NDArray -from pysindy import BaseDifferentiation, FiniteDifference, SINDy # type: ignore +from pysindy import BaseDifferentiation, FiniteDifference, SINDy -from . import gridsearch, odes, pdes +from . import gridsearch # type: ignore +from . import odes, pdes from .utils import SINDyTrialData, make_model # noqa: F401 this_module = importlib.import_module(__name__) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 47547ba..6a5654c 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -4,8 +4,9 @@ import pysindy as ps from gen_experiments.data import _signal_avg_power +from gen_experiments.gridsearch.typing import NestedDict, SeriesDef, SeriesList from gen_experiments.plotting import _PlotPrefs -from gen_experiments.utils import FullSINDyTrialData, NestedDict, SeriesDef, SeriesList +from gen_experiments.utils import FullSINDyTrialData T = TypeVar("T") U = TypeVar("U") diff --git a/src/gen_experiments/data.py b/src/gen_experiments/data.py index 0d651f2..0fc9760 100644 --- a/src/gen_experiments/data.py +++ b/src/gen_experiments/data.py @@ -7,7 +7,8 @@ import numpy as np import scipy -from gen_experiments.utils import Float1D, Float2D, GridsearchResultDetails +from gen_experiments.gridsearch.typing import GridsearchResultDetails +from gen_experiments.utils import Float1D, Float2D INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12} TRIALS_FOLDER = Path(__file__).parent.absolute() / "trials" @@ -56,7 +57,7 @@ def gen_data( noise_abs = 0.1 rng = np.random.default_rng(seed) if x0_center is None: - x0_center = np.zeros((n_coord)) + x0_center = np.zeros((n_coord), dtype=np.float_) t_train = np.arange(0, t_end, dt, 
dtype=np.float_) t_train_span = (t_train[0], t_train[-1]) if nonnegative: diff --git a/src/gen_experiments/debug.py b/src/gen_experiments/debug.py deleted file mode 100644 index a1193ca..0000000 --- a/src/gen_experiments/debug.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Annotated, Generic, TypedDict, TypeVar - -import numpy as np -from numpy.typing import DTypeLike, NBitBase, NDArray - -# T = TypeVar("T") - -# class Foo[T]: -# items: list[T] - -# def __init__(self, thing: T): -# self.items = [thing, thing] - -# Bar = - - -T = TypeVar("T", bound=np.generic) -Foo = NDArray[T] -Bar = Annotated[NDArray, "foobar"] - -lil_foo = NDArray[np.void] - - -def baz(qux: Foo[np.void]): - pass diff --git a/src/gen_experiments/gridsearch.py b/src/gen_experiments/gridsearch/__init__.py similarity index 64% rename from src/gen_experiments/gridsearch.py rename to src/gen_experiments/gridsearch/__init__.py index 81f548f..89e5d3b 100644 --- a/src/gen_experiments/gridsearch.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -1,5 +1,5 @@ +from collections.abc import Iterable from copy import copy -from dataclasses import dataclass, field from functools import partial from logging import getLogger from pprint import pformat @@ -9,14 +9,15 @@ Any, Callable, Collection, - Iterable, Optional, Sequence, TypeVar, + cast, ) import matplotlib.pyplot as plt import numpy as np +from matplotlib.axes import Axes from numpy.typing import DTypeLike, NDArray from scipy.stats import kstest @@ -24,18 +25,17 @@ from gen_experiments import config from gen_experiments.odes import plot_ode_panel from gen_experiments.plotting import _PlotPrefs -from gen_experiments.utils import ( +from gen_experiments.typing import FloatND +from gen_experiments.utils import simulate_test_data + +from .typing import ( + ExpResult, GridsearchResult, GridsearchResultDetails, NestedDict, - SavedData, + SavedGridPoint, SeriesDef, SeriesList, - SINDyTrialData, - _amax_to_full_inds, - _argopt, - _grid_locator_match, - 
simulate_test_data, ) pformat = partial(pformat, indent=4, sort_dicts=True) @@ -47,30 +47,99 @@ SkinnySpecs = Optional[tuple[tuple[str, ...], tuple[OtherSliceDef, ...]]] -@dataclass(frozen=True) -class GridLocator: - """A specification of which points in a gridsearch to match. +def _amax_to_full_inds( + amax_inds: Collection[tuple[int | slice, int] | ellipsis] | ellipsis, + amax_arrays: list[list[GridsearchResult[np.void]]], +) -> set[tuple[int, ...]]: + """Find full indexers to selected elements of argmax arrays + + Args: + amax_inds: selection statemtent of which argmaxes to return. + amax_arrays: arrays of indexes to full gridsearch that are responsible for + the computed max values. First level of nesting reflects series, second + level reflects which grid grid axis. + Returns: + all indexers to full gridsearch that are requested by amax_inds + """ + + def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: + return tuple(int(el) for el in cast(Iterable, tuple_like)) + + if amax_inds is ...: # grab each element from arrays in list of lists of arrays + return { + np_to_primitive(el) + for ar_list in amax_arrays + for arr in ar_list + for el in arr.flatten() + } + all_inds = set() + for plot_axis_results in [el for series in amax_arrays for el in series]: + for ind in amax_inds: + if ind is ...: # grab each element from arrays in list of lists of arrays + all_inds |= { + np_to_primitive(el) + for ar_list in amax_arrays + for arr in ar_list + for el in arr.flatten() + } + elif isinstance(ind[0], int): + all_inds |= {np_to_primitive(cast(np.void, plot_axis_results[ind]))} + else: # ind[0] is slice(None) + all_inds |= {np_to_primitive(el) for el in plot_axis_results[ind]} + return all_inds + + +_EqTester = TypeVar("_EqTester") + + +def _param_normalize(val: _EqTester) -> _EqTester | str: + if type(val).__eq__ == object.__eq__: + return repr(val) + else: + return val + + +def _grid_locator_match( + exp_params: dict[str, Any], + exp_ind: tuple[int, ...], + 
param_spec: Collection[dict[str, Any]], + ind_spec: Collection[tuple[int, ...]], +) -> bool: + """Determine whether experimental parameters match a specification + + Logical clause applied is: - Rather than specifying the exact point in the mega-grid of every - varied axis, specify by result, e.g "all of the points from the - Kalman series that had the best mean squared error as noise was - varied. + OR((exp_params MATCHES params for params in param_spec)) + AND + OR((exp_ind MATCHES ind for ind in ind_spec)) + Treats OR of an empty collection as falsy Args: - metric: The metric in which to find results. An ellipsis means "any metrics" - keep_axis: The grid-varied parameter in which to find results, or a tuple of - that axis and position along that axis. To search a particular value of - that parameter, use the param_match kwarg. An ellipsis means "any axis" - param_match: A collection of dictionaries to match parameter values represented - by points in the gridsearch. Dictionary equality is checked for every - non-callable value; for callable values, it is applied to the grid - parameters and must return a boolean. Logical OR is applied across the - collection + exp_params: the experiment parameters to evaluate + exp_ind: the experiemnt's full-size grid index to evaluate + param_spec: the criteria for matching exp_params + ind_spec: the criteria for matching exp_ind """ + found_match = False + for params_or in param_spec: + params_or = {k: _param_normalize(v) for k, v in params_or.items()} - metric: str | ellipsis = field(default=...) - keep_axis: str | tuple[str, int] | ellipsis = field(default=...) 
- param_match: Collection[dict[str, Any]] = field(default=()) + try: + if all( + _param_normalize(exp_params[param]) == value + for param, value in params_or.items() + ): + found_match = True + break + except KeyError: + pass + for ind_or in ind_spec: + # exp_ind doesn't include metric, so skip first metric + if _index_in(exp_ind, ind_or[1:]): + break + else: + return False + return found_match def run( @@ -78,7 +147,7 @@ def run( group: str, grid_params: list[str], grid_vals: list[Sequence], - grid_decisions: Sequence[str], + grid_decisions: list[str], other_params: dict, series_params: Optional[SeriesList] = None, metrics: Sequence[str] = (), @@ -118,8 +187,8 @@ def run( metric_ordering = [base_ex.metric_ordering[metric] for metric in metrics] n_plotparams = len([decide for decide in grid_decisions if decide == "plot"]) series_searches: list[tuple[list[GridsearchResult], list[GridsearchResult]]] = [] - intermediate_data: list[SavedData] = [] - plot_data: list[SavedData] = [] + intermediate_data: list[SavedGridPoint] = [] + plot_data: list[SavedGridPoint] = [] if base_group is not None: other_params["group"] = base_group for s_counter, series_data in enumerate(series_params.series_list): @@ -145,14 +214,12 @@ def run( for ind_counter, ind in enumerate(gridpoint_selector): print(f"Calculating series {s_counter}, gridpoint{ind_counter}", end="\r") new_seed = rng.integers(1000) - param_updates = {} for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals): - param_updates[key] = val_list[axis_ind] - curr_other_params.update(param_updates) + curr_other_params[key] = val_list[axis_ind] curr_results, grid_data = base_ex.run( new_seed, **curr_other_params, display=False, return_all=True ) - grid_data: SINDyTrialData + grid_data: ExpResult intermediate_data.append( {"params": curr_other_params.flatten(), "pind": ind, "data": grid_data} ) @@ -205,7 +272,6 @@ def run( series_searches, (ser.name for ser in series_params.series_list) ): plot( - fig, subplots, 
metrics, grid_params, @@ -244,8 +310,7 @@ def run( def plot( - fig: plt.Figure, - subplots: Sequence[plt.Axes], + subplots: NDArray[Annotated[np.void, "Axes"]], metrics: Sequence[str], grid_params: Sequence[str], grid_vals: Sequence[Sequence[float] | np.ndarray], @@ -253,11 +318,13 @@ def plot( name: str, legends: bool, ): + if len(metrics) == 0: + raise ValueError("Nothing to plot") for m_ind_row, m_name in enumerate(metrics): for col, (param_name, x_ticks, param_search) in enumerate( zip(grid_params, grid_vals, grid_searches) ): - ax = subplots[m_ind_row, col] + ax = cast(Axes, subplots[m_ind_row, col]) ax.plot(x_ticks, param_search[m_ind_row], label=name) x_ticks = np.array(x_ticks) if m_name in ("coeff_mse", "coeff_mae"): @@ -278,16 +345,54 @@ def plot( if col == 0: ax.set_ylabel(f"{m_name}") if legends: - ax.legend() + ax.legend() # type: ignore T = TypeVar("T", bound=np.generic) +def _argopt( + arr: FloatND, axis: Optional[int | tuple[int, ...]] = None, opt: str = "max" +) -> NDArray[np.void]: + """Calculate the argmax/min, but accept tuple axis. + + Ignores NaN values + + Args: + arr: an array to search + axis: The axis or axes to search through for the argmax/argmin. + opt: One of {"max", "min"} + + Returns: + array of indices for the argopt. 
If m = arr.ndim and n = len(axis), + the final result will be an array of ndim = m-n with elements being + tuples of length m + """ + dtype: DTypeLike = [(f"f{axind}", "i") for axind in range(arr.ndim)] + if axis is None: + axis = () + axis = (axis,) if isinstance(axis, int) else axis + keep_axes = tuple(sorted(set(range(arr.ndim)) - set(axis))) + keep_shape = tuple(arr.shape[ax] for ax in keep_axes) + result = np.empty(keep_shape, dtype=dtype) + optfun = np.nanargmax if opt == "max" else np.nanargmin + for slise in np.ndindex(keep_shape): + sub_arr = arr + # since we shrink shape, we need to chop of axes from the end + for ind, ax in zip(reversed(slise), reversed(keep_axes)): + sub_arr = np.take(sub_arr, ind, ax) + subind_max = np.unravel_index(optfun(sub_arr), sub_arr.shape) + fullind_max = np.empty((arr.ndim), int) + fullind_max[np.array(keep_axes, int)] = slise + fullind_max[np.array(axis, int)] = subind_max + result[slise] = tuple(fullind_max) + return result + + def _marginalize_grid_views( grid_decisions: Iterable[str], results: Annotated[NDArray[T], "shape (n_metrics, *n_gridsearch_values)"], - max_or_min: Sequence[str] = None, + max_or_min: Sequence[str], ) -> tuple[list[GridsearchResult[T]], list[GridsearchResult]]: """Marginalize unnecessary dimensions by taking max across axes. @@ -303,7 +408,7 @@ def _marginalize_grid_views( a list of the metric optima for each plottable grid decision, and a list of the flattened argoptima. 
""" - arg_dtype: DTypeLike = ",".join(results.ndim * "i") + arg_dtype = np.dtype(",".join(results.ndim * "i")) plot_param_inds = [ind for ind, val in enumerate(grid_decisions) if val == "plot"] grid_searches = [] args_maxes = [] @@ -315,7 +420,11 @@ def _marginalize_grid_views( ) sub_arrs = [] for m_ind, (result, opt) in enumerate(zip(results, max_or_min)): - pad_m_ind = np.vectorize(lambda tp: np.void((m_ind, *tp), dtype=arg_dtype)) + + def _metric_pad(tp: tuple[int, ...]) -> np.void: + return np.void((m_ind, *tp), dtype=arg_dtype) + + pad_m_ind = np.vectorize(_metric_pad) arg_max = pad_m_ind(_argopt(result, reduce_axes, opt)) sub_arrs.append(arg_max) @@ -326,7 +435,7 @@ def _marginalize_grid_views( def _ndindex_skinny( - shape: tuple[int], + shape: tuple[int, ...], thin_axes: Optional[Sequence[int]] = None, thin_slices: Optional[Sequence[OtherSliceDef]] = None, ): @@ -360,6 +469,7 @@ def _ndindex_skinny( n_thin = len(thin_axes) thin_slices = n_thin * ((n_thin - 1) * (0,),) full_indexes = np.ndindex(shape) + thin_slices = cast(Sequence[OtherSliceDef], thin_slices) def ind_checker(multi_index): """Check if a multi_index meets thin index criteria""" @@ -415,3 +525,59 @@ def _curr_skinny_specs( ) where_others.append(new_criteria) return skinny_param_inds, tuple(where_others) + + +def strict_find_grid_match( + results: GridsearchResultDetails, + *, + params: Optional[dict[str, Any]] = None, + ind_spec: Optional[tuple[int | slice, int] | ellipsis] = None, +) -> ExpResult: + if params is None: + params = {} + if ind_spec is None: + ind_spec = ... 
+ matches = [] + amax_arrays = [ + [single_ser_and_axis[1] for single_ser_and_axis in single_series_all_axes] + for _, single_series_all_axes in results["series_data"].items() + ] + full_inds = _amax_to_full_inds((ind_spec,), amax_arrays) + + for trajectory in results["plot_data"]: + if _grid_locator_match( + trajectory["params"], trajectory["pind"], (params,), full_inds + ): + matches.append(trajectory) + + if len(matches) > 1: + raise ValueError("Specification is nonunique; matched multiple results") + if len(matches) == 0: + raise ValueError("Could not find a match") + return matches[0]["data"] + + +def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> bool: + """Determine whether base indexing tuple will match given numpy index""" + if len(base) > len(tgt): + return False + curr_ax = 0 + for ax, ind in enumerate(tgt): + if isinstance(ind, int): + try: + if ind != base[curr_ax]: + return False + except IndexError: + return False + elif isinstance(ind, slice): + if not (ind.start is None and ind.stop is None and ind.step is None): + raise ValueError("Only slices allowed are `slice(None)`") + elif ind is ...: + base_ind_remaining = len(base) - curr_ax + tgt_ind_remaining = len(tgt) - ax + # ellipsis can take 0 or more spots + curr_ax += max(base_ind_remaining - tgt_ind_remaining, -1) + curr_ax += 1 + if curr_ax == len(base): + return True + return False diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py new file mode 100644 index 0000000..1aa2c98 --- /dev/null +++ b/src/gen_experiments/gridsearch/typing.py @@ -0,0 +1,201 @@ +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass, field +from types import EllipsisType as ellipsis +from typing import Annotated, Any, Collection, Optional, Sequence, TypedDict, TypeVar + +import numpy as np +from numpy.typing import NDArray + + +@dataclass(frozen=True) +class GridLocator: + """A specification 
of which points in a gridsearch to match. + + Rather than specifying the exact point in the mega-grid of every + varied axis, specify by result, e.g "all of the points from the + Kalman series that had the best mean squared error as noise was + varied. + + Args: + metric: The metric in which to find results. An ellipsis means "any metrics" + keep_axis: The grid-varied parameter in which to find results, or a tuple of + that axis and position along that axis. To search a particular value of + that parameter, use the param_match kwarg. An ellipsis means "any axis" + param_match: A collection of dictionaries to match parameter values represented + by points in the gridsearch. Dictionary equality is checked for every + non-callable value; for callable values, it is applied to the grid + parameters and must return a boolean. Logical OR is applied across the + collection + """ + + metric: str | ellipsis = field(default=...) + keep_axis: str | tuple[str, int] | ellipsis = field(default=...) + param_match: Collection[dict[str, Any]] = field(default=()) + + +T = TypeVar("T", bound=np.generic) +GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] +SeriesData = Annotated[ + list[ + tuple[ + Annotated[GridsearchResult, "metrics"], + Annotated[GridsearchResult[np.void], "arg_opts"], + ] + ], + "len=n_grid_axes", +] + +ExpResult = dict[str, Any] + + +class SavedGridPoint(TypedDict): + params: dict + pind: tuple[int] + data: ExpResult + + +class GridsearchResultDetails(TypedDict): + system: str + plot_data: list[SavedGridPoint] + series_data: dict[str, SeriesData] + metrics: list[str] + grid_params: list[str] + grid_vals: list[Sequence] + grid_axes: dict[str, Collection[float]] + main: float + + +@dataclass +class SeriesDef: + """The details of constructing the ragged axes of a grid search. 
+ + The concept of a SeriesDef refers to a slice along a single axis of + a grid search in conjunction with another axis (or axes) + whose size or meaning differs along different slices. + + Attributes: + name: The name of the slice, as a label for printing + static_param: the constant parameter to this slice. Then key is + the name of the parameter, as understood by the experiment + Conceptually, the key serves as an index of this slice in + the gridsearch. + grid_params: the keys of the parameters in the experiment that + vary along jagged axis for this slice + grid_vals: the values of the parameters in the experiment that + vary along jagged axis for this slice + + Example: + + truck_wheels = SeriesDef( + "Truck", + {"vehicle": "flatbed_truck"}, + ["vehicle.n_wheels"], + [[10, 18]] + ) + + """ + + name: str + static_param: dict + grid_params: list[str] + grid_vals: list[Iterable] + + +@dataclass +class SeriesList: + """Specify the ragged slices of a grid search. + + As an example, consider a grid search of miles per gallon for + different vehicles, in different routes, with different tires. + Since different tires fit on different vehicles, the tire axis would + be ragged, varying along the vehicle axis. + + Truck = SeriesDef("trucks") + + Attributes: + param_name: the key of the parameter in the experiment that + varies along the series axis. + print_name: the print name of the parameter in the experiment + that varies along the series axis. 
+ series_list: Each element of the series axis + + Example: + + truck_wheels = SeriesDef( + "Truck", + {"vehicle": "flatbed_truck"}, + ["vehicle.n_wheels"], + [[10, 18]] + ) + bike_tires = SeriesDef( + "Bike", + {"vehicle": "bicycle"}, + ["vehicle.tires"], + [["gravel_tires", "road_tires"]] + ) + VehicleOptions = SeriesList( + "vehicle", + "Vehicle Types", + [truck_wheels, bike_tires] + ) + + """ + + param_name: Optional[str] + print_name: Optional[str] + series_list: list[SeriesDef] + + +class NestedDict(defaultdict): + """A dictionary that splits all keys by ".", creating a sub-dict. + + Args: see superclass + + Example: + + >>> foo = NestedDict("a.b"=1) + >>> foo["a.c"] = 2 + >>> foo["a"]["b"] + 1 + """ + + def __missing__(self, key): + try: + prefix, subkey = key.split(".", 1) + except ValueError: + raise KeyError(key) + return self[prefix][subkey] + + def __setitem__(self, key, value): + if "." in key: + prefix, suffix = key.split(".", 1) + if self.get(prefix) is None: + self[prefix] = NestedDict() + return self[prefix].__setitem__(suffix, value) + else: + return super().__setitem__(key, value) + + def update(self, other: dict): # type: ignore + try: + for k, v in other.items(): + self.__setitem__(k, v) + except: # noqa: E722 + super().update(other) + + def flatten(self): + """Flattens a nested dictionary without mutating. Returns new dict""" + + def _flatten(nested_d: dict) -> dict: + new = {} + for key, value in nested_d.items(): + if not isinstance(key, str): + raise TypeError("Only string keys allowed in flattening") + if not isinstance(value, dict): + new[key] = value + continue + for sub_key, sub_value in _flatten(value).items(): + new[key + "." 
+ sub_key] = sub_value + return new + + return _flatten(self) diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 0aa09bd..461da5a 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from types import EllipsisType as ellipsis -from typing import Annotated, Callable, Collection, Literal, Mapping, Sequence +from typing import Annotated, Any, Callable, Collection, Literal, Sequence import matplotlib.pyplot as plt import numpy as np @@ -29,8 +29,8 @@ class _PlotPrefs: plot: bool = True rel_noise: Literal[False] | Callable = False - grid_params_match: Collection[dict] = field(default_factory=lambda: ()) - grid_ind_match: Collection[tuple[int | slice, int]] | ellipsis = field( + grid_params_match: Collection[dict[str, Any]] = field(default_factory=lambda: ()) + grid_ind_match: Collection[tuple[str | slice, int]] | ellipsis = field( default_factory=lambda: ... ) @@ -40,9 +40,9 @@ def __bool__(self): def plot_coefficients( coefficients: Annotated[np.ndarray, "(n_coord, n_features)"], - input_features: Sequence[str] = None, - feature_names: Sequence[str] = None, - ax: bool = None, + input_features: Sequence[str], + feature_names: Sequence[str], + ax: Axes, **heatmap_kws, ): def detex(input: str) -> str: @@ -61,9 +61,6 @@ def detex(input: str) -> str: feature_names = [f"f{k}" for k in range(coefficients.shape[1])] with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}): - if ax is None: - fig, ax = plt.subplots(1, 1) - heatmap_args = { "xticklabels": input_features, "yticklabels": feature_names, @@ -85,8 +82,8 @@ def detex(input: str) -> str: def compare_coefficient_plots( coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"], coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"], - input_features: Sequence[str] = None, - feature_names: Sequence[str] = None, + input_features: Sequence[str], + feature_names: Sequence[str], ): 
"""Create plots of true and estimated coefficients.""" n_cols = len(coefficients_est) @@ -203,6 +200,8 @@ def plot_training_data(x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.nda ax0 = fig.add_subplot(1, 2, 1) elif x_train.shape[-1] == 3: ax0 = fig.add_subplot(1, 2, 1, projection="3d") + else: + raise ValueError("Too many or too few coordinates to plot") plot_training_trajectory(ax0, x_train, x_true, x_smooth) ax0.legend() ax0.set(title="Training data") @@ -276,7 +275,7 @@ def _plot_test_sim_data_3d( def plot_test_trajectories( x_test: np.ndarray, x_sim: np.ndarray, t_test: np.ndarray, t_sim: np.ndarray -) -> Mapping[str, np.ndarray]: +) -> None: """Plot a test trajectory Args: @@ -288,19 +287,17 @@ def plot_test_trajectories( A dict with two keys, "t_sim" (the simulation times) and "x_sim" (the simulated trajectory) """ - fig, axs = plt.subplots(x_test.shape[1], 1, sharex=True, figsize=(7, 9)) + _, axs = plt.subplots(x_test.shape[1], 1, sharex=True, figsize=(7, 9)) plt.suptitle("Test Trajectories by Dimension") plot_test_sim_data_1d_panel(axs, x_test, x_sim, t_test, t_sim) axs[-1].legend() plt.suptitle("Full Test Trajectories") if x_test.shape[1] == 2: - fig, axs = plt.subplots(1, 2, figsize=(10, 4.5)) + _, axs = plt.subplots(1, 2, figsize=(10, 4.5)) _plot_test_sim_data_2d(axs, x_test, x_sim) elif x_test.shape[1] == 3: - fig, axs = plt.subplots( - 1, 2, figsize=(10, 4.5), subplot_kw={"projection": "3d"} - ) + _, axs = plt.subplots(1, 2, figsize=(10, 4.5), subplot_kw={"projection": "3d"}) _plot_test_sim_data_3d(axs, x_test, x_sim) else: raise ValueError("Can only plot 2d or 3d data.") diff --git a/src/gen_experiments/typing.py b/src/gen_experiments/typing.py new file mode 100644 index 0000000..e797377 --- /dev/null +++ b/src/gen_experiments/typing.py @@ -0,0 +1,10 @@ +from typing import TypeVar + +import numpy as np +from numpy.typing import NBitBase + +NpFlt = np.dtype[np.floating[NBitBase]] +Float1D = np.ndarray[tuple[int], NpFlt] +Float2D = 
np.ndarray[tuple[int, int], NpFlt] +Shape = TypeVar("Shape", bound=tuple[int, ...]) +FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]] diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 76d3408..800b178 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -1,18 +1,5 @@ -from collections import defaultdict -from collections.abc import Iterable -from dataclasses import dataclass from itertools import chain -from types import EllipsisType as ellipsis -from typing import ( - Annotated, - Any, - Collection, - Optional, - Sequence, - TypedDict, - TypeVar, - cast, -) +from typing import Annotated, TypedDict, cast from warnings import warn import auto_ks as aks @@ -21,13 +8,9 @@ import pysindy as ps import sklearn import sklearn.metrics -from numpy.typing import DTypeLike, NBitBase, NDArray +from numpy.typing import NDArray -NpFlt = np.dtype[np.floating[NBitBase]] -Float1D = np.ndarray[tuple[int], NpFlt] -Float2D = np.ndarray[tuple[int, int], NpFlt] -Shape = TypeVar("Shape", bound=tuple[int, ...]) -FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]] +from .typing import Float1D, Float2D, FloatND class SINDyTrialData(TypedDict): @@ -56,36 +39,6 @@ class FullSINDyTrialData(SINDyTrialData): x_sim: np.ndarray -class SavedData(TypedDict): - params: dict - pind: tuple[int] - data: SINDyTrialData | FullSINDyTrialData - - -T = TypeVar("T", bound=np.generic) -GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] -SeriesData = Annotated[ - list[ - tuple[ - Annotated[GridsearchResult, "metrics"], - Annotated[GridsearchResult[np.void], "arg_opts"], - ] - ], - "len=n_grid_axes", -] - - -class GridsearchResultDetails(TypedDict): - system: str - plot_data: list[SavedData] - series_data: dict[str, SeriesData] - metrics: list[str] - grid_params: list[str] - grid_vals: list[Sequence] - grid_axes: dict[str, Collection[float]] - main: float - - def diff_lookup(kind): normalized_kind = 
kind.lower().replace(" ", "") if normalized_kind == "finitedifference": @@ -264,141 +217,6 @@ def simulate_test_data(model: ps.SINDy, dt: float, x_test: Float2D) -> SINDyTria return {"t_sim": t_sim, "x_sim": x_sim, "t_test": t_test} -@dataclass -class SeriesDef: - """The details of constructing the ragged axes of a grid search. - - The concept of a SeriesDef refers to a slice along a single axis of - a grid search in conjunction with another axis (or axes) - whose size or meaning differs along different slices. - - Attributes: - name: The name of the slice, as a label for printing - static_param: the constant parameter to this slice. Then key is - the name of the parameter, as understood by the experiment - Conceptually, the key serves as an index of this slice in - the gridsearch. - grid_params: the keys of the parameters in the experiment that - vary along jagged axis for this slice - grid_vals: the values of the parameters in the experiment that - vary along jagged axis for this slice - - Example: - - truck_wheels = SeriesDef( - "Truck", - {"vehicle": "flatbed_truck"}, - ["vehicle.n_wheels"], - [[10, 18]] - ) - - """ - - name: str - static_param: dict - grid_params: Optional[Sequence[str]] - grid_vals: Optional[list[Sequence]] - - -@dataclass -class SeriesList: - """Specify the ragged slices of a grid search. - - As an example, consider a grid search of miles per gallon for - different vehicles, in different routes, with different tires. - Since different tires fit on different vehicles, the tire axis would - be ragged, varying along the vehicle axis. - - Truck = SeriesDef("trucks") - - Attributes: - param_name: the key of the parameter in the experiment that - varies along the series axis. - print_name: the print name of the parameter in the experiment - that varies along the series axis. 
- series_list: Each element of the series axis - - Example: - - truck_wheels = SeriesDef( - "Truck", - {"vehicle": "flatbed_truck"}, - ["vehicle.n_wheels"], - [[10, 18]] - ) - bike_tires = SeriesDef( - "Bike", - {"vehicle": "bicycle"}, - ["vehicle.tires"], - [["gravel_tires", "road_tires"]] - ) - VehicleOptions = SeriesList( - "vehicle", - "Vehicle Types", - [truck_wheels, bike_tires] - ) - - """ - - param_name: Optional[str] - print_name: Optional[str] - series_list: list[SeriesDef] - - -class NestedDict(defaultdict): - """A dictionary that splits all keys by ".", creating a sub-dict. - - Args: see superclass - - Example: - - >>> foo = NestedDict("a.b"=1) - >>> foo["a.c"] = 2 - >>> foo["a"]["b"] - 1 - """ - - def __missing__(self, key): - try: - prefix, subkey = key.split(".", 1) - except ValueError: - raise KeyError(key) - return self[prefix][subkey] - - def __setitem__(self, key, value): - if "." in key: - prefix, suffix = key.split(".", 1) - if self.get(prefix) is None: - self[prefix] = NestedDict() - return self[prefix].__setitem__(suffix, value) - else: - return super().__setitem__(key, value) - - def update(self, other: dict): # type: ignore - try: - for k, v in other.items(): - self.__setitem__(k, v) - except: # noqa: E722 - super().update(other) - - def flatten(self): - """Flattens a nested dictionary without mutating. Returns new dict""" - - def _flatten(nested_d: dict) -> dict: - new = {} - for key, value in nested_d.items(): - if not isinstance(key, str): - raise TypeError("Only string keys allowed in flattening") - if not isinstance(value, dict): - new[key] = value - continue - for sub_key, sub_value in _flatten(value).items(): - new[key + "." 
+ sub_key] = sub_value - return new - - return _flatten(self) - - def kalman_generalized_cv( times: np.ndarray, measurements: np.ndarray, alpha0: float = 1, detail=False ): @@ -435,192 +253,3 @@ def proj(curr_params, t): est_Q = np.linalg.inv(params.W_neg_sqrt @ params.W_neg_sqrt.T) est_alpha = 1 / (est_Q / Qi).mean() return est_alpha - - -def _amax_to_full_inds( - amax_inds: Collection[tuple[int | slice, int] | ellipsis] | ellipsis, - amax_arrays: list[list[GridsearchResult[np.void]]], -) -> set[tuple[int, ...]]: - """Find full indexers to selected elements of argmax arrays - - Args: - amax_inds: selection statemtent of which argmaxes to return. - amax_arrays: arrays of indexes to full gridsearch that are responsible for - the computed max values. First level of nesting reflects series, second - level reflects which grid grid axis. - Returns: - all indexers to full gridsearch that are requested by amax_inds - """ - - def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: - return tuple(int(el) for el in cast(Iterable, tuple_like)) - - if amax_inds is ...: # grab each element from arrays in list of lists of arrays - return { - np_to_primitive(el) - for ar_list in amax_arrays - for arr in ar_list - for el in arr.flatten() - } - all_inds = set() - for plot_axis_results in [el for series in amax_arrays for el in series]: - for ind in amax_inds: - if ind is ...: # grab each element from arrays in list of lists of arrays - all_inds |= { - np_to_primitive(el) - for ar_list in amax_arrays - for arr in ar_list - for el in arr.flatten() - } - elif isinstance(ind[0], int): - all_inds |= {np_to_primitive(cast(np.void, plot_axis_results[ind]))} - else: # ind[0] is slice(None) - all_inds |= {np_to_primitive(el) for el in plot_axis_results[ind]} - return all_inds - - -def _argopt( - arr: FloatND, axis: Optional[int | tuple[int, ...]] = None, opt: str = "max" -) -> NDArray[np.void]: - """Calculate the argmax/min, but accept tuple axis. 
- - Ignores NaN values - - Args: - arr: an array to search - axis: The axis or axes to search through for the argmax/argmin. - opt: One of {"max", "min"} - - Returns: - array of indices for the argopt. If m = arr.ndim and n = len(axis), - the final result will be an array of ndim = m-n with elements being - tuples of length m - """ - dtype: DTypeLike = [(f"f{axind}", "i") for axind in range(arr.ndim)] - if axis is None: - axis = () - axis = (axis,) if isinstance(axis, int) else axis - keep_axes = tuple(sorted(set(range(arr.ndim)) - set(axis))) - keep_shape = tuple(arr.shape[ax] for ax in keep_axes) - result = np.empty(keep_shape, dtype=dtype) - optfun = np.nanargmax if opt == "max" else np.nanargmin - for slise in np.ndindex(keep_shape): - sub_arr = arr - # since we shrink shape, we need to chop of axes from the end - for ind, ax in zip(reversed(slise), reversed(keep_axes)): - sub_arr = np.take(sub_arr, ind, ax) - subind_max = np.unravel_index(optfun(sub_arr), sub_arr.shape) - fullind_max = np.empty((arr.ndim), int) - fullind_max[np.array(keep_axes, int)] = slise - fullind_max[np.array(axis, int)] = subind_max - result[slise] = tuple(fullind_max) - return result - - -def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> bool: - """Determine whether base indexing tuple will match given numpy index""" - if len(base) > len(tgt): - return False - curr_ax = 0 - for ax, ind in enumerate(tgt): - if isinstance(ind, int): - try: - if ind != base[curr_ax]: - return False - except IndexError: - return False - elif isinstance(ind, slice): - if not (ind.start is None and ind.stop is None and ind.step is None): - raise ValueError("Only slices allowed are `slice(None)`") - elif ind is ...: - base_ind_remaining = len(base) - curr_ax - tgt_ind_remaining = len(tgt) - ax - # ellipsis can take 0 or more spots - curr_ax += max(base_ind_remaining - tgt_ind_remaining, -1) - curr_ax += 1 - if curr_ax == len(base): - return True - return False - - -def 
_grid_locator_match( - exp_params: dict[str, Any], - exp_ind: tuple[int, ...], - param_spec: Collection[dict[str, Any]], - ind_spec: Collection[tuple[int, ...]], -) -> bool: - """Determine whether experimental parameters match a specification - - Logical clause applied is: - - OR((exp_params MATCHES params for params in param_spec)) - AND - OR((exp_ind MATCHES ind for ind in ind_spec)) - - Treats OR of an empty collection as falsy - Args: - exp_params: the experiment parameters to evaluate - exp_ind: the experiemnt's full-size grid index to evaluate - param_spec: the criteria for matching exp_params - ind_spec: the criteria for matching exp_ind - """ - found_match = False - for params_or in param_spec: - params_or = {k: _param_normalize(v) for k, v in params_or.items()} - - try: - if all( - _param_normalize(exp_params[param]) == value - for param, value in params_or.items() - ): - found_match = True - break - except KeyError: - pass - for ind_or in ind_spec: - # exp_ind doesn't include metric, so skip first metric - if _index_in(exp_ind, ind_or[1:]): - break - else: - return False - return found_match - - -def strict_find_grid_match( - results: GridsearchResultDetails, - *, - params: Optional[dict[str, Any]] = None, - ind_spec: Optional[tuple[int | slice, int] | ellipsis] = None, -) -> SINDyTrialData: - if params is None: - params = {} - if ind_spec is None: - ind_spec = ... 
- matches = [] - amax_arrays = [ - [single_ser_and_axis[1] for single_ser_and_axis in single_series_all_axes] - for _, single_series_all_axes in results["series_data"].items() - ] - full_inds = _amax_to_full_inds((ind_spec,), amax_arrays) - - for trajectory in results["plot_data"]: - if _grid_locator_match( - trajectory["params"], trajectory["pind"], (params,), full_inds - ): - matches.append(trajectory) - - if len(matches) > 1: - raise ValueError("Specification is nonunique; matched multiple results") - if len(matches) == 0: - raise ValueError("Could not find a match") - return matches[0]["data"] - - -_EqTester = TypeVar("_EqTester") - - -def _param_normalize(val: _EqTester) -> _EqTester | str: - if type(val).__eq__ == object.__eq__: - return repr(val) - else: - return val diff --git a/tests/test_all.py b/tests/test_all.py index 962d0df..e6b3eab 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from gen_experiments import gridsearch, utils +import gen_experiments.gridsearch.typing +from gen_experiments import gridsearch def test_thin_indexing(): @@ -82,9 +83,9 @@ def test_marginalize_grid_views(): def test_argopt_tuple_axis(): - arr = np.arange(16).reshape(2, 2, 2, 2) + arr = np.arange(16, dtype=np.float_).reshape(2, 2, 2, 2) arr[0, 0, 0, 0] = 1000 - result = utils._argopt(arr, (1, 3)) + result = gridsearch._argopt(arr, (1, 3)) expected = np.array( [[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 0, 1), (1, 1, 1, 1)]], dtype="i,i,i,i" ) @@ -92,18 +93,18 @@ def test_argopt_tuple_axis(): def test_argopt_empty_tuple_axis(): - arr = np.arange(4).reshape(4) - result = utils._argopt(arr, ()) + arr = np.arange(4, dtype=np.float_).reshape(4) + result = gridsearch._argopt(arr, ()) expected = np.array([(0,), (1,), (2,), (3,)], dtype=[("f0", "i")]) np.testing.assert_array_equal(result, expected) - result = utils._argopt(arr, None) + result = gridsearch._argopt(arr, None) pass def test_argopt_int_axis(): - arr = 
np.arange(8).reshape(2, 2, 2) + arr = np.arange(8, dtype=np.float_).reshape(2, 2, 2) arr[0, 0, 0] = 1000 - result = utils._argopt(arr, 1) + result = gridsearch._argopt(arr, 1) expected = np.array([[(0, 0, 0), (0, 1, 1)], [(1, 1, 0), (1, 1, 1)]], dtype="i,i,i") np.testing.assert_array_equal(result, expected) @@ -112,19 +113,21 @@ def test_index_in(): match_me = (1, ..., slice(None), 3) good = [(1, 2, 1, 3), (1, 1, 3)] for g in good: - assert utils._index_in(g, match_me) + assert gridsearch._index_in(g, match_me) bad = [(1, 3), (1, 1, 2), (1, 1, 1, 2)] for b in bad: - assert not utils._index_in(b, match_me) + assert not gridsearch._index_in(b, match_me) def test_index_in_errors(): with pytest.raises(ValueError): - utils._index_in((1,), (slice(-1),)) + gridsearch._index_in((1,), (slice(-1),)) def test_flatten_nested_dict(): - deep = utils.NestedDict(a=utils.NestedDict(b=1)) + deep = gen_experiments.gridsearch.typing.NestedDict( + a=gen_experiments.gridsearch.typing.NestedDict(b=1) + ) result = deep.flatten() assert deep != result expected = {"a.b": 1} @@ -154,7 +157,7 @@ def test_grid_locator_match(): ), ] for param_spec, ind_spec in good_specs: - assert utils._grid_locator_match(m_params, m_ind, param_spec, ind_spec) + assert gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) bad_specs = [ ((), ((0, 1),)), @@ -164,7 +167,7 @@ def test_grid_locator_match(): ((), ()), ] for param_spec, ind_spec in bad_specs: - assert not utils._grid_locator_match(m_params, m_ind, param_spec, ind_spec) + assert not gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) def test_amax_to_full_inds(): @@ -181,7 +184,9 @@ def test_amax_to_full_inds(): def test_flatten_nested_bad_dict(): with pytest.raises(TypeError, match="keywords must be strings"): - utils.NestedDict(**{1: utils.NestedDict(b=1)}) + gen_experiments.gridsearch.typing.NestedDict( + **{1: gen_experiments.gridsearch.typing.NestedDict(b=1)} + ) with pytest.raises(TypeError, match="Only string 
keys allowed"): - deep = utils.NestedDict(a={1: 1}) + deep = gen_experiments.gridsearch.typing.NestedDict(a={1: 1}) deep.flatten() From c51cea37d4bb7a1dd5bb7c83adb0159f2c4b8aa9 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 6 Mar 2024 16:33:50 +0000 Subject: [PATCH 20/46] type(gridsearch): Clarify types in SkinnySpec --- src/gen_experiments/config.py | 16 +++++++++++----- src/gen_experiments/gridsearch/__init__.py | 22 ++++++---------------- src/gen_experiments/gridsearch/typing.py | 16 +++++++++++++++- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 6a5654c..2b4d4db 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -1,10 +1,16 @@ +from collections.abc import Iterable from typing import TypeVar import numpy as np import pysindy as ps from gen_experiments.data import _signal_avg_power -from gen_experiments.gridsearch.typing import NestedDict, SeriesDef, SeriesList +from gen_experiments.gridsearch.typing import ( + NestedDict, + SeriesDef, + SeriesList, + SkinnySpecs, +) from gen_experiments.plotting import _PlotPrefs from gen_experiments.utils import FullSINDyTrialData @@ -302,7 +308,7 @@ def addn(x): "duration-absnoise": ["sim_params.t_end", "sim_params.noise_abs"], "rel_noise": ["sim_params.t_end", "sim_params.noise_rel"], } -grid_vals = { +grid_vals: dict[str, list[Iterable]] = { "test": [[5, 10, 15, 20]], "abs_noise": [[0.1, 0.5, 1, 2, 4, 8]], "abs_noise-kalman": [[0.1, 0.5, 1, 2, 4, 8], [0.1, 0.5, 1, 2, 4, 8]], @@ -320,7 +326,7 @@ def addn(x): "lorenzk": ["plot", "plot", "max"], "plot2": ["plot", "plot"], } -diff_series = { +diff_series: dict[str, SeriesDef] = { "kalman1": SeriesDef( "Kalman", diff_params["kalman"], @@ -376,7 +382,7 @@ def addn(x): [[5, 8, 12, 15]], ), } -series_params = { +series_params: dict[str, SeriesList] = { "test": SeriesList( "diff_params", "Differentiation Method", @@ 
-435,7 +441,7 @@ def addn(x): } -skinny_specs = { +skinny_specs: dict[str, SkinnySpecs] = { "exp3": ( ("sim_params.noise_abs", "diff_params.meas_var"), ((identity,), (identity,)), diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 89e5d3b..f2fdf13 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -4,16 +4,7 @@ from logging import getLogger from pprint import pformat from types import EllipsisType as ellipsis -from typing import ( - Annotated, - Any, - Callable, - Collection, - Optional, - Sequence, - TypeVar, - cast, -) +from typing import Annotated, Any, Collection, Optional, Sequence, TypeVar, cast import matplotlib.pyplot as plt import numpy as np @@ -33,9 +24,11 @@ GridsearchResult, GridsearchResultDetails, NestedDict, + OtherSliceDef, SavedGridPoint, SeriesDef, SeriesList, + SkinnySpecs, ) pformat = partial(pformat, indent=4, sort_dicts=True) @@ -43,9 +36,6 @@ name = "gridsearch" lookup_dict = vars(config) -OtherSliceDef = tuple[int | Callable] -SkinnySpecs = Optional[tuple[tuple[str, ...], tuple[OtherSliceDef, ...]]] - def _amax_to_full_inds( amax_inds: Collection[tuple[int | slice, int] | ellipsis] | ellipsis, @@ -149,10 +139,10 @@ def run( grid_vals: list[Sequence], grid_decisions: list[str], other_params: dict, + skinny_specs: SkinnySpecs, series_params: Optional[SeriesList] = None, metrics: Sequence[str] = (), plot_prefs: _PlotPrefs = _PlotPrefs(True, False, ()), - skinny_specs: SkinnySpecs = None, ) -> GridsearchResultDetails: """Run a grid-search wrapper of an experiment. 
@@ -471,7 +461,7 @@ def _ndindex_skinny( full_indexes = np.ndindex(shape) thin_slices = cast(Sequence[OtherSliceDef], thin_slices) - def ind_checker(multi_index): + def ind_checker(multi_index: tuple[int, ...]) -> bool: """Check if a multi_index meets thin index criteria""" matches = [] # check whether multi_index matches criteria of any thin_axis @@ -484,7 +474,7 @@ def ind_checker(multi_index): if callable(slice_ind): slice_ind = slice_ind(multi_index[ax1]) # would check: "== slice_ind", but must allow slice_ind = -1 - match *= multi_index[ax2] == range(shape[ax2])[slice_ind] + match &= multi_index[ax2] == range(shape[ax2])[slice_ind] matches.append(match) return any(matches) diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index 1aa2c98..b78efbb 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -2,11 +2,25 @@ from collections.abc import Iterable from dataclasses import dataclass, field from types import EllipsisType as ellipsis -from typing import Annotated, Any, Collection, Optional, Sequence, TypedDict, TypeVar +from typing import ( + Annotated, + Any, + Callable, + Collection, + Optional, + Sequence, + TypedDict, + TypeVar, +) import numpy as np from numpy.typing import NDArray +OtherSliceDef = tuple[(int | Callable[[int], int]), ...] 
+"""For a particular index of one gridsearch axis, which indexes of other axes +should be included.""" +SkinnySpecs = tuple[tuple[str, ...], tuple[OtherSliceDef, ...]] + @dataclass(frozen=True) class GridLocator: From f3425fe7534b72c4c565bb3c2131836a97cca160 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:35:14 +0000 Subject: [PATCH 21/46] feat(gridsearch): Added find_gridpoints() --- src/gen_experiments/gridsearch/__init__.py | 53 +++++++++++++++++++++- src/gen_experiments/gridsearch/typing.py | 14 ++++-- src/gen_experiments/plotting.py | 2 +- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index f2fdf13..5963da9 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -21,6 +21,7 @@ from .typing import ( ExpResult, + GridLocator, GridsearchResult, GridsearchResultDetails, NestedDict, @@ -141,7 +142,7 @@ def run( other_params: dict, skinny_specs: SkinnySpecs, series_params: Optional[SeriesList] = None, - metrics: Sequence[str] = (), + metrics: tuple[str, ...] = (), plot_prefs: _PlotPrefs = _PlotPrefs(True, False, ()), ) -> GridsearchResultDetails: """Run a grid-search wrapper of an experiment. 
@@ -209,7 +210,6 @@ def run( curr_results, grid_data = base_ex.run( new_seed, **curr_other_params, display=False, return_all=True ) - grid_data: ExpResult intermediate_data.append( {"params": curr_other_params.flatten(), "pind": ind, "data": grid_data} ) @@ -290,6 +290,7 @@ def run( }, "metrics": metrics, "grid_params": grid_params, + "plot_params": [decide for decide in grid_decisions if decide == "plot"], "grid_vals": grid_vals, "main": max( grid[main_metric_ind].max() @@ -571,3 +572,51 @@ def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> if curr_ax == len(base): return True return False + + +def find_gridpoints( + find: GridLocator, where: GridsearchResultDetails +) -> list[SavedGridPoint]: + """Find results wrapped by gridsearch that match criteria + + Args: + find: the criteria + where: The overall results of the gridsearch + + Returns: + A list of the wrapped results, representing points in the gridsearch. + """ + results: list[SavedGridPoint] = [] + partial_match: list[tuple[int, ...]] = [] + if find.metric is ...: + metric_sl = slice(None) + else: + ind = where["metrics"].index(find.metric) + metric_sl = slice(ind, ind + 1) + if find.keep_axis is ...: + keep_axis_sl = slice(None) + keep_el_sl = slice(None) + else: + ind = where["plot_params"].index(find.keep_axis[0]) + keep_axis_sl = slice(ind, ind + 1) + if find.keep_axis[1] is ...: + keep_el_sl = slice(None) + else: + ind = find.keep_axis[1] + keep_el_sl = slice(ind, ind + 1) + for ser in where["series_data"].values(): + ser = ser[keep_axis_sl] + for _, amax_arr in ser: + amax_want = amax_arr[metric_sl, keep_el_sl].flatten() + partial_match.extend(amax_want) + + params_or = { + k: _param_normalize(v) for params in find.param_match for k, v in params.items() + } + for point in where["plot_data"]: + if point["pind"] in partial_match and all( + _param_normalize(point["params"][param]) == value + for param, value in params_or.items() + ): + results.append(point) + return results 
diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index b78efbb..b2df514 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -44,7 +44,7 @@ class GridLocator: """ metric: str | ellipsis = field(default=...) - keep_axis: str | tuple[str, int] | ellipsis = field(default=...) + keep_axis: tuple[str, int | ellipsis] | ellipsis = field(default=...) param_match: Collection[dict[str, Any]] = field(default=()) @@ -64,6 +64,14 @@ class GridLocator: class SavedGridPoint(TypedDict): + """The results at a point in the gridsearch. + + Args: + params: the full list of parameters identifying this variant + pind: the full index in the series' grid + data: the results of the experiment + """ + params: dict pind: tuple[int] data: ExpResult @@ -73,10 +81,10 @@ class GridsearchResultDetails(TypedDict): system: str plot_data: list[SavedGridPoint] series_data: dict[str, SeriesData] - metrics: list[str] + metrics: tuple[str, ...] grid_params: list[str] + plot_params: list[str] grid_vals: list[Sequence] - grid_axes: dict[str, Collection[float]] main: float diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 461da5a..153724a 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -30,7 +30,7 @@ class _PlotPrefs: plot: bool = True rel_noise: Literal[False] | Callable = False grid_params_match: Collection[dict[str, Any]] = field(default_factory=lambda: ()) - grid_ind_match: Collection[tuple[str | slice, int]] | ellipsis = field( + grid_ind_match: Collection[tuple[int | slice, int]] | ellipsis = field( default_factory=lambda: ... 
) From abfa77ffd7d4e8d9847b671ce67eeb9a2f4c0508 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:56:39 +0000 Subject: [PATCH 22/46] test(gridsearch): Test find_gridpoints with GridLocator --- src/gen_experiments/gridsearch/typing.py | 2 +- tests/test_all.py | 51 ++++++++++++++++++++++-- tests/test_gridsearch.py | 0 3 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 tests/test_gridsearch.py diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index b2df514..a0b8103 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -73,7 +73,7 @@ class SavedGridPoint(TypedDict): """ params: dict - pind: tuple[int] + pind: tuple[int, ...] data: ExpResult diff --git a/tests/test_all.py b/tests/test_all.py index e6b3eab..2f61795 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -3,6 +3,11 @@ import gen_experiments.gridsearch.typing from gen_experiments import gridsearch +from gen_experiments.gridsearch.typing import ( + GridsearchResultDetails, + SavedGridPoint, + SeriesData, +) def test_thin_indexing(): @@ -170,6 +175,46 @@ def test_grid_locator_match(): assert not gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) +def test_find_gridpoints(): + exact_locator = gridsearch.GridLocator( + "mse", ("sim_params.t_end", ...), [{"diff_params.alpha": 0.1}] + ) + callable_locator = gridsearch.GridLocator( + ..., ..., [{"diff_params.alpha": lambda x: x < 0.2}] + ) + want: SavedGridPoint = { + "params": {"diff_params.alpha": 0.1}, + "pind": (1,), + "data": {}, + } + dont_want: SavedGridPoint = { + "params": {"diff_params.alpha": 0.2}, + "pind": (0,), + "data": {}, + } + max_amax: SeriesData = [ + (np.ones((2, 2)), np.array([[(1,), (1,)], [(0,), (0,)]])), + (np.ones((2, 2)), np.array([[(0,), (0,)], [(0,), (0,)]])), + ] + full_details: GridsearchResultDetails = { + "system": "sho", + 
"plot_data": [want, dont_want], + "series_data": {"foo": max_amax}, + "metrics": ("mse", "mae"), + "grid_params": ["sim_params.t_end", "sim_params.noise"], + "plot_params": ["sim_params.t_end", "sim_params.noise"], + "grid_vals": [[1, 2, 3], [4, 5, 6]], + "main": 1, + } + results = gridsearch.find_gridpoints(exact_locator, full_details) + for result in results: + assert result is want + + results = gridsearch.find_gridpoints(callable_locator, full_details) + for result in results: + assert result is want + + def test_amax_to_full_inds(): amax_inds = ((1, 1), (slice(None), 0)) arr = np.array([[(0, 0), (0, 1)], [(1, 0), (1, 1)]], dtype="i,i") @@ -183,10 +228,10 @@ def test_amax_to_full_inds(): def test_flatten_nested_bad_dict(): + nested = {1: gen_experiments.gridsearch.typing.NestedDict(b=1)} + # Testing the very thing that causes a typing error, thus ignoring with pytest.raises(TypeError, match="keywords must be strings"): - gen_experiments.gridsearch.typing.NestedDict( - **{1: gen_experiments.gridsearch.typing.NestedDict(b=1)} - ) + gen_experiments.gridsearch.typing.NestedDict(**nested) # type: ignore with pytest.raises(TypeError, match="Only string keys allowed"): deep = gen_experiments.gridsearch.typing.NestedDict(a={1: 1}) deep.flatten() diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py new file mode 100644 index 0000000..e69de29 From 787247846c24d723dc5f5d6fdf1706bd6b6ba986 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 10 Mar 2024 17:19:13 +0000 Subject: [PATCH 23/46] types: Check more files in CI --- pyproject.toml | 8 +++++++- src/gen_experiments/__init__.py | 3 +-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 81edcf7..c223195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,13 @@ addopts = '-m "not slow"' markers = ["slow"] [tool.mypy] -files = ["src/gen_experiments/__init__.py", "src/gen_experiments/utils.py"] +files = 
[ + "src/gen_experiments/__init__.py", + "src/gen_experiments/utils.py", + "src/gen_experiments/gridsearch/typing.py", + "tests/test_all.py", + "tests/test_gridsearch.py", +] [[tool.mypy.overrides]] module="auto_ks.*" diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index 61f64fa..cd4a3dc 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -7,8 +7,7 @@ from numpy.typing import NDArray from pysindy import BaseDifferentiation, FiniteDifference, SINDy -from . import gridsearch # type: ignore -from . import odes, pdes +from . import gridsearch, odes, pdes from .utils import SINDyTrialData, make_model # noqa: F401 this_module = importlib.import_module(__name__) From b0b3517f7c3a373a14f7acd16206b6a06c531cd0 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 10 Mar 2024 17:28:28 +0000 Subject: [PATCH 24/46] test: move gridsearch tests to their own file, move NestedDict into core --- src/gen_experiments/config.py | 8 +- src/gen_experiments/gridsearch/__init__.py | 3 +- src/gen_experiments/gridsearch/typing.py | 55 ----- src/gen_experiments/typing.py | 55 +++++ tests/test_all.py | 226 +-------------------- tests/test_gridsearch.py | 207 +++++++++++++++++++ 6 files changed, 270 insertions(+), 284 deletions(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 2b4d4db..9847370 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -5,13 +5,9 @@ import pysindy as ps from gen_experiments.data import _signal_avg_power -from gen_experiments.gridsearch.typing import ( - NestedDict, - SeriesDef, - SeriesList, - SkinnySpecs, -) +from gen_experiments.gridsearch.typing import SeriesDef, SeriesList, SkinnySpecs from gen_experiments.plotting import _PlotPrefs +from gen_experiments.typing import NestedDict from gen_experiments.utils import FullSINDyTrialData T = TypeVar("T") diff --git 
a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 5963da9..009908b 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -16,7 +16,7 @@ from gen_experiments import config from gen_experiments.odes import plot_ode_panel from gen_experiments.plotting import _PlotPrefs -from gen_experiments.typing import FloatND +from gen_experiments.typing import FloatND, NestedDict from gen_experiments.utils import simulate_test_data from .typing import ( @@ -24,7 +24,6 @@ GridLocator, GridsearchResult, GridsearchResultDetails, - NestedDict, OtherSliceDef, SavedGridPoint, SeriesDef, diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index a0b8103..ace8c62 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -1,4 +1,3 @@ -from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass, field from types import EllipsisType as ellipsis @@ -167,57 +166,3 @@ class SeriesList: param_name: Optional[str] print_name: Optional[str] series_list: list[SeriesDef] - - -class NestedDict(defaultdict): - """A dictionary that splits all keys by ".", creating a sub-dict. - - Args: see superclass - - Example: - - >>> foo = NestedDict("a.b"=1) - >>> foo["a.c"] = 2 - >>> foo["a"]["b"] - 1 - """ - - def __missing__(self, key): - try: - prefix, subkey = key.split(".", 1) - except ValueError: - raise KeyError(key) - return self[prefix][subkey] - - def __setitem__(self, key, value): - if "." 
in key: - prefix, suffix = key.split(".", 1) - if self.get(prefix) is None: - self[prefix] = NestedDict() - return self[prefix].__setitem__(suffix, value) - else: - return super().__setitem__(key, value) - - def update(self, other: dict): # type: ignore - try: - for k, v in other.items(): - self.__setitem__(k, v) - except: # noqa: E722 - super().update(other) - - def flatten(self): - """Flattens a nested dictionary without mutating. Returns new dict""" - - def _flatten(nested_d: dict) -> dict: - new = {} - for key, value in nested_d.items(): - if not isinstance(key, str): - raise TypeError("Only string keys allowed in flattening") - if not isinstance(value, dict): - new[key] = value - continue - for sub_key, sub_value in _flatten(value).items(): - new[key + "." + sub_key] = sub_value - return new - - return _flatten(self) diff --git a/src/gen_experiments/typing.py b/src/gen_experiments/typing.py index e797377..815ccd8 100644 --- a/src/gen_experiments/typing.py +++ b/src/gen_experiments/typing.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import TypeVar import numpy as np @@ -8,3 +9,57 @@ Float2D = np.ndarray[tuple[int, int], NpFlt] Shape = TypeVar("Shape", bound=tuple[int, ...]) FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]] + + +class NestedDict(defaultdict): + """A dictionary that splits all keys by ".", creating a sub-dict. + + Args: see superclass + + Example: + + >>> foo = NestedDict("a.b"=1) + >>> foo["a.c"] = 2 + >>> foo["a"]["b"] + 1 + """ + + def __missing__(self, key): + try: + prefix, subkey = key.split(".", 1) + except ValueError: + raise KeyError(key) + return self[prefix][subkey] + + def __setitem__(self, key, value): + if "." 
in key: + prefix, suffix = key.split(".", 1) + if self.get(prefix) is None: + self[prefix] = NestedDict() + return self[prefix].__setitem__(suffix, value) + else: + return super().__setitem__(key, value) + + def update(self, other: dict): # type: ignore + try: + for k, v in other.items(): + self.__setitem__(k, v) + except: # noqa: E722 + super().update(other) + + def flatten(self): + """Flattens a nested dictionary without mutating. Returns new dict""" + + def _flatten(nested_d: dict) -> dict: + new = {} + for key, value in nested_d.items(): + if not isinstance(key, str): + raise TypeError("Only string keys allowed in flattening") + if not isinstance(value, dict): + new[key] = value + continue + for sub_key, sub_value in _flatten(value).items(): + new[key + "." + sub_key] = sub_value + return new + + return _flatten(self) diff --git a/tests/test_all.py b/tests/test_all.py index 2f61795..13274fd 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,237 +1,21 @@ -import numpy as np import pytest -import gen_experiments.gridsearch.typing -from gen_experiments import gridsearch -from gen_experiments.gridsearch.typing import ( - GridsearchResultDetails, - SavedGridPoint, - SeriesData, -) - - -def test_thin_indexing(): - result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (-1,)))) - expected = { - (0, 0, 0), - (0, 1, 0), - (1, 0, 0), - (1, 1, 0), - (1, 0, 1), - (1, 1, 1), - } - assert result == expected - - -def test_thin_indexing_default(): - result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), None)) - expected = { - (0, 0, 0), - (0, 1, 0), - (1, 0, 0), - (1, 1, 0), - (0, 0, 1), - (0, 1, 1), - } - assert result == expected - - -def test_thin_indexing_callable(): - result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (lambda x: x,)))) - expected = { - (0, 0, 0), - (0, 1, 0), - (1, 0, 0), - (1, 1, 0), - (1, 0, 1), - (1, 1, 1), - } - assert result == expected - - -def test_curr_skinny_specs(): - grid_params = ["a", "c", "e", "f"] - 
skinny_specs = ( - ("a", "b", "c", "d", "e"), - ((1, 2, 3, 4), (0, 2, 3, 4), (0, 1, 3, 4), (0, 1, 2, 4), (0, 1, 2, 3)), - ) - ind_skinny, where_others = gridsearch._curr_skinny_specs(skinny_specs, grid_params) - assert ind_skinny == [0, 1, 2] - assert where_others == ((2, 4), (0, 4), (0, 2)) - - -def test_marginalize_grid_views(): - arr = np.arange(16, dtype=np.float_).reshape( - 2, 2, 2, 2 - ) # (metrics, param1, param2, param3) - arr[0, 0, 0, 0] = 1000 - arr[-1, -1, -1, 0] = -1000 - arr[0, 0, 0, 1] = np.nan - grid_decisions = ["plot", "max", "plot"] - opts = ["max", "min"] - res_val, res_ind = gridsearch._marginalize_grid_views(grid_decisions, arr, opts) - assert len(res_val) == len([dec for dec in grid_decisions if dec == "plot"]) - expected_val = [ - np.array([[1000, 7], [8, -1000]]), - np.array([[1000, 7], [-1000, 9]]), - ] - for result, expected in zip(res_val, expected_val): - np.testing.assert_array_equal(result, expected) - - ts = "i,i,i,i" - expected_ind = [ - np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 0, 0, 0), (1, 1, 1, 0)]], ts), - np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 1, 0), (1, 0, 0, 1)]], ts), - ] - for result, expected in zip(res_ind, expected_ind): - np.testing.assert_array_equal(result, expected) - - -def test_argopt_tuple_axis(): - arr = np.arange(16, dtype=np.float_).reshape(2, 2, 2, 2) - arr[0, 0, 0, 0] = 1000 - result = gridsearch._argopt(arr, (1, 3)) - expected = np.array( - [[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 0, 1), (1, 1, 1, 1)]], dtype="i,i,i,i" - ) - np.testing.assert_array_equal(result, expected) - - -def test_argopt_empty_tuple_axis(): - arr = np.arange(4, dtype=np.float_).reshape(4) - result = gridsearch._argopt(arr, ()) - expected = np.array([(0,), (1,), (2,), (3,)], dtype=[("f0", "i")]) - np.testing.assert_array_equal(result, expected) - result = gridsearch._argopt(arr, None) - pass - - -def test_argopt_int_axis(): - arr = np.arange(8, dtype=np.float_).reshape(2, 2, 2) - arr[0, 0, 0] = 1000 - result = 
gridsearch._argopt(arr, 1) - expected = np.array([[(0, 0, 0), (0, 1, 1)], [(1, 1, 0), (1, 1, 1)]], dtype="i,i,i") - np.testing.assert_array_equal(result, expected) - - -def test_index_in(): - match_me = (1, ..., slice(None), 3) - good = [(1, 2, 1, 3), (1, 1, 3)] - for g in good: - assert gridsearch._index_in(g, match_me) - bad = [(1, 3), (1, 1, 2), (1, 1, 1, 2)] - for b in bad: - assert not gridsearch._index_in(b, match_me) - - -def test_index_in_errors(): - with pytest.raises(ValueError): - gridsearch._index_in((1,), (slice(-1),)) +from gen_experiments.typing import NestedDict def test_flatten_nested_dict(): - deep = gen_experiments.gridsearch.typing.NestedDict( - a=gen_experiments.gridsearch.typing.NestedDict(b=1) - ) + deep = NestedDict(a=NestedDict(b=1)) result = deep.flatten() assert deep != result expected = {"a.b": 1} assert result == expected -def test_grid_locator_match(): - m_params = {"sim_params.t_end": 10, "foo": 1} - m_ind = (0, 1) - # Effectively testing the clause: (x OR y OR ...) AND (a OR b OR ...) 
- # Note: OR() with no args is falsy - # also note first index is stripped ind_spec - good_specs = [ - (({"sim_params.t_end": 10},), ((1, 0, 1),)), - (({"sim_params.t_end": 10},), ((1, 0, 1), (1, 0, ...))), - (({"sim_params.t_end": 10}, {"foo": 1}), ((1, 0, 1),)), - (({"sim_params.t_end": 10}, {"bar: 1"}), ((1, 0, 1),)), - ( - ({"sim_params.t_end": 10},), - ( - (1, 0, 1), - ( - 1, - 1, - ), - ), - ), - ] - for param_spec, ind_spec in good_specs: - assert gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) - - bad_specs = [ - ((), ((0, 1),)), - (({"sim_params.t_end": 10},), ()), - (({"sim_params.t_end": 9},), ((1, 0, 1),)), - (({"sim_params.t_end": 10},), ((1, 0, 0),)), - ((), ()), - ] - for param_spec, ind_spec in bad_specs: - assert not gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) - - -def test_find_gridpoints(): - exact_locator = gridsearch.GridLocator( - "mse", ("sim_params.t_end", ...), [{"diff_params.alpha": 0.1}] - ) - callable_locator = gridsearch.GridLocator( - ..., ..., [{"diff_params.alpha": lambda x: x < 0.2}] - ) - want: SavedGridPoint = { - "params": {"diff_params.alpha": 0.1}, - "pind": (1,), - "data": {}, - } - dont_want: SavedGridPoint = { - "params": {"diff_params.alpha": 0.2}, - "pind": (0,), - "data": {}, - } - max_amax: SeriesData = [ - (np.ones((2, 2)), np.array([[(1,), (1,)], [(0,), (0,)]])), - (np.ones((2, 2)), np.array([[(0,), (0,)], [(0,), (0,)]])), - ] - full_details: GridsearchResultDetails = { - "system": "sho", - "plot_data": [want, dont_want], - "series_data": {"foo": max_amax}, - "metrics": ("mse", "mae"), - "grid_params": ["sim_params.t_end", "sim_params.noise"], - "plot_params": ["sim_params.t_end", "sim_params.noise"], - "grid_vals": [[1, 2, 3], [4, 5, 6]], - "main": 1, - } - results = gridsearch.find_gridpoints(exact_locator, full_details) - for result in results: - assert result is want - - results = gridsearch.find_gridpoints(callable_locator, full_details) - for result in results: - 
assert result is want - - -def test_amax_to_full_inds(): - amax_inds = ((1, 1), (slice(None), 0)) - arr = np.array([[(0, 0), (0, 1)], [(1, 0), (1, 1)]], dtype="i,i") - amax_arrays = [[arr, arr], [arr]] - result = gridsearch._amax_to_full_inds(amax_inds, amax_arrays) - expected = {(0, 0), (1, 1), (1, 0)} - assert result == expected - result = gridsearch._amax_to_full_inds(..., amax_arrays) - expected |= {(0, 1)} - return result == expected - - def test_flatten_nested_bad_dict(): - nested = {1: gen_experiments.gridsearch.typing.NestedDict(b=1)} + nested = {1: NestedDict(b=1)} # Testing the very thing that causes a typing error, thus ignoring with pytest.raises(TypeError, match="keywords must be strings"): - gen_experiments.gridsearch.typing.NestedDict(**nested) # type: ignore + NestedDict(**nested) # type: ignore with pytest.raises(TypeError, match="Only string keys allowed"): - deep = gen_experiments.gridsearch.typing.NestedDict(a={1: 1}) + deep = NestedDict(a={1: 1}) deep.flatten() diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index e69de29..10f1841 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -0,0 +1,207 @@ +import numpy as np +import pytest + +from gen_experiments import gridsearch +from gen_experiments.gridsearch.typing import ( + GridsearchResultDetails, + SavedGridPoint, + SeriesData, +) + + +def test_thin_indexing(): + result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (-1,)))) + expected = { + (0, 0, 0), + (0, 1, 0), + (1, 0, 0), + (1, 1, 0), + (1, 0, 1), + (1, 1, 1), + } + assert result == expected + + +def test_thin_indexing_default(): + result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), None)) + expected = { + (0, 0, 0), + (0, 1, 0), + (1, 0, 0), + (1, 1, 0), + (0, 0, 1), + (0, 1, 1), + } + assert result == expected + + +def test_thin_indexing_callable(): + result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (lambda x: x,)))) + expected = { + (0, 0, 0), + (0, 1, 0), + (1, 
0, 0), + (1, 1, 0), + (1, 0, 1), + (1, 1, 1), + } + assert result == expected + + +def test_curr_skinny_specs(): + grid_params = ["a", "c", "e", "f"] + skinny_specs = ( + ("a", "b", "c", "d", "e"), + ((1, 2, 3, 4), (0, 2, 3, 4), (0, 1, 3, 4), (0, 1, 2, 4), (0, 1, 2, 3)), + ) + ind_skinny, where_others = gridsearch._curr_skinny_specs(skinny_specs, grid_params) + assert ind_skinny == [0, 1, 2] + assert where_others == ((2, 4), (0, 4), (0, 2)) + + +def test_marginalize_grid_views(): + arr = np.arange(16, dtype=np.float_).reshape( + 2, 2, 2, 2 + ) # (metrics, param1, param2, param3) + arr[0, 0, 0, 0] = 1000 + arr[-1, -1, -1, 0] = -1000 + arr[0, 0, 0, 1] = np.nan + grid_decisions = ["plot", "max", "plot"] + opts = ["max", "min"] + res_val, res_ind = gridsearch._marginalize_grid_views(grid_decisions, arr, opts) + assert len(res_val) == len([dec for dec in grid_decisions if dec == "plot"]) + expected_val = [ + np.array([[1000, 7], [8, -1000]]), + np.array([[1000, 7], [-1000, 9]]), + ] + for result, expected in zip(res_val, expected_val): + np.testing.assert_array_equal(result, expected) + + ts = "i,i,i,i" + expected_ind = [ + np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 0, 0, 0), (1, 1, 1, 0)]], ts), + np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 1, 0), (1, 0, 0, 1)]], ts), + ] + for result, expected in zip(res_ind, expected_ind): + np.testing.assert_array_equal(result, expected) + + +def test_argopt_tuple_axis(): + arr = np.arange(16, dtype=np.float_).reshape(2, 2, 2, 2) + arr[0, 0, 0, 0] = 1000 + result = gridsearch._argopt(arr, (1, 3)) + expected = np.array( + [[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 0, 1), (1, 1, 1, 1)]], dtype="i,i,i,i" + ) + np.testing.assert_array_equal(result, expected) + + +def test_argopt_empty_tuple_axis(): + arr = np.arange(4, dtype=np.float_).reshape(4) + result = gridsearch._argopt(arr, ()) + expected = np.array([(0,), (1,), (2,), (3,)], dtype=[("f0", "i")]) + np.testing.assert_array_equal(result, expected) + result = 
gridsearch._argopt(arr, None) + pass + + +def test_argopt_int_axis(): + arr = np.arange(8, dtype=np.float_).reshape(2, 2, 2) + arr[0, 0, 0] = 1000 + result = gridsearch._argopt(arr, 1) + expected = np.array([[(0, 0, 0), (0, 1, 1)], [(1, 1, 0), (1, 1, 1)]], dtype="i,i,i") + np.testing.assert_array_equal(result, expected) + + +def test_index_in(): + match_me = (1, ..., slice(None), 3) + good = [(1, 2, 1, 3), (1, 1, 3)] + for g in good: + assert gridsearch._index_in(g, match_me) + bad = [(1, 3), (1, 1, 2), (1, 1, 1, 2)] + for b in bad: + assert not gridsearch._index_in(b, match_me) + + +def test_index_in_errors(): + with pytest.raises(ValueError): + gridsearch._index_in((1,), (slice(-1),)) + + +def test_amax_to_full_inds(): + amax_inds = ((1, 1), (slice(None), 0)) + arr = np.array([[(0, 0), (0, 1)], [(1, 0), (1, 1)]], dtype="i,i") + amax_arrays = [[arr, arr], [arr]] + result = gridsearch._amax_to_full_inds(amax_inds, amax_arrays) + expected = {(0, 0), (1, 1), (1, 0)} + assert result == expected + result = gridsearch._amax_to_full_inds(..., amax_arrays) + expected |= {(0, 1)} + return result == expected + + +def test_find_gridpoints(): + exact_locator = gridsearch.GridLocator( + "mse", ("sim_params.t_end", ...), [{"diff_params.alpha": 0.1}] + ) + callable_locator = gridsearch.GridLocator( + ..., ..., [{"diff_params.alpha": lambda x: x < 0.2}] + ) + want: SavedGridPoint = { + "params": {"diff_params.alpha": 0.1}, + "pind": (1,), + "data": {}, + } + dont_want: SavedGridPoint = { + "params": {"diff_params.alpha": 0.2}, + "pind": (0,), + "data": {}, + } + max_amax: SeriesData = [ + (np.ones((2, 2)), np.array([[(1,), (1,)], [(0,), (0,)]])), + (np.ones((2, 2)), np.array([[(0,), (0,)], [(0,), (0,)]])), + ] + full_details: GridsearchResultDetails = { + "system": "sho", + "plot_data": [want, dont_want], + "series_data": {"foo": max_amax}, + "metrics": ("mse", "mae"), + "grid_params": ["sim_params.t_end", "sim_params.noise"], + "plot_params": ["sim_params.t_end", 
"sim_params.noise"], + "grid_vals": [[1, 2, 3], [4, 5, 6]], + "main": 1, + } + results = gridsearch.find_gridpoints(exact_locator, full_details) + for result in results: + assert result is want + + results = gridsearch.find_gridpoints(callable_locator, full_details) + for result in results: + assert result is want + + +def test_grid_locator_match(): + m_params = {"sim_params.t_end": 10, "foo": 1} + m_ind = (0, 1) + # Effectively testing the clause: (x OR y OR ...) AND (a OR b OR ...) + # Note: OR() with no args is falsy + # also note first index is stripped ind_spec + good_specs = [ + (({"sim_params.t_end": 10},), ((1, 0, 1),)), + (({"sim_params.t_end": 10},), ((1, 0, 1), (1, 0, ...))), + (({"sim_params.t_end": 10}, {"foo": 1}), ((1, 0, 1),)), + (({"sim_params.t_end": 10}, {"bar: 1"}), ((1, 0, 1),)), + (({"sim_params.t_end": 10},), ((1, 0, 1), (1, 1))), + ] + for param_spec, ind_spec in good_specs: + assert gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) + + bad_specs = [ + ((), ((0, 1),)), + (({"sim_params.t_end": 10},), ()), + (({"sim_params.t_end": 9},), ((1, 0, 1),)), + (({"sim_params.t_end": 10},), ((1, 0, 0),)), + ((), ()), + ] + for param_spec, ind_spec in bad_specs: + assert not gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) From ff0a8bc5e76b73fb89ddbe6686cad40594dc94e2 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 10 Mar 2024 19:32:08 +0000 Subject: [PATCH 25/46] test(gridsearch): Clarify and fix test_find_gridpoints --- src/gen_experiments/gridsearch/typing.py | 7 ++-- tests/test_gridsearch.py | 42 ++++++++++++++++-------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index ace8c62..eac2fe5 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -30,6 +30,8 @@ class GridLocator: Kalman series that had the 
best mean squared error as noise was varied. + Logical AND is applied across the metric, keep_axis, AND param_match specifications. + Args: metric: The metric in which to find results. An ellipsis means "any metrics" keep_axis: The grid-varied parameter in which to find results, or a tuple of @@ -38,8 +40,9 @@ class GridLocator: param_match: A collection of dictionaries to match parameter values represented by points in the gridsearch. Dictionary equality is checked for every non-callable value; for callable values, it is applied to the grid - parameters and must return a boolean. Logical OR is applied across the - collection + parameters and must return a boolean. For values whose equality is object + equality (often, mutable objects), the repr is used. Logical OR is applied + across the collection. """ metric: str | ellipsis = field(default=...) diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 10f1841..5503a16 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -1,4 +1,5 @@ import numpy as np +import pysindy as ps import pytest from gen_experiments import gridsearch @@ -140,20 +141,15 @@ def test_amax_to_full_inds(): return result == expected -def test_find_gridpoints(): - exact_locator = gridsearch.GridLocator( - "mse", ("sim_params.t_end", ...), [{"diff_params.alpha": 0.1}] - ) - callable_locator = gridsearch.GridLocator( - ..., ..., [{"diff_params.alpha": lambda x: x < 0.2}] - ) +@pytest.fixture +def gridsearch_results(): want: SavedGridPoint = { - "params": {"diff_params.alpha": 0.1}, + "params": {"diff_params.alpha": 0.1, "opt_params": ps.STLSQ()}, "pind": (1,), "data": {}, } dont_want: SavedGridPoint = { - "params": {"diff_params.alpha": 0.2}, + "params": {"diff_params.alpha": 0.2, "opt_params": ps.SSR()}, "pind": (0,), "data": {}, } @@ -171,13 +167,31 @@ def test_find_gridpoints(): "grid_vals": [[1, 2, 3], [4, 5, 6]], "main": 1, } - results = gridsearch.find_gridpoints(exact_locator, full_details) - for result in 
results: - assert result is want - - results = gridsearch.find_gridpoints(callable_locator, full_details) + return want, full_details + + +@pytest.mark.parametrize( + "locator", + ( + gridsearch.GridLocator( + "mse", ("sim_params.t_end", ...), [{"diff_params.alpha": 0.1}] + ), + gridsearch.GridLocator( + "mse", ("sim_params.t_end", ...), [{"opt_params": ps.STLSQ()}] + ), + gridsearch.GridLocator(..., ..., [{"diff_params.alpha": lambda x: x < 0.2}]), + gridsearch.GridLocator( + ..., ..., [{"diff_params.alpha": 0.1}, {"diff_params.alpha": 0.3}] + ), + ), + ids=("exact", "object", "callable", "or"), +) +def test_find_gridpoints(gridsearch_results, locator): + want, full_details = gridsearch_results + results = gridsearch.find_gridpoints(locator, full_details) for result in results: assert result is want + assert want in results def test_grid_locator_match(): From 855458df6c3ca63a0b26596271e9b0d55d8e0474 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 11 Mar 2024 12:23:12 +0000 Subject: [PATCH 26/46] fix: Apply OR correctly in matching param dicts --- src/gen_experiments/gridsearch/__init__.py | 19 +++++++++++-------- src/gen_experiments/gridsearch/typing.py | 2 +- tests/test_gridsearch.py | 4 +--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 009908b..f6d2f80 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -609,13 +609,16 @@ def find_gridpoints( amax_want = amax_arr[metric_sl, keep_el_sl].flatten() partial_match.extend(amax_want) - params_or = { - k: _param_normalize(v) for params in find.param_match for k, v in params.items() - } + params_or = tuple( + {k: _param_normalize(v) for k, v in params_match.items()} + for params_match in find.params_or + ) for point in where["plot_data"]: - if point["pind"] in partial_match and all( - 
_param_normalize(point["params"][param]) == value - for param, value in params_or.items() - ): - results.append(point) + for params_match in params_or: + if all( + _param_normalize(point["params"][param]) == value + for param, value in params_match.items() + ): + results.append(point) + break return results diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index eac2fe5..91ea1fd 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -47,7 +47,7 @@ class GridLocator: metric: str | ellipsis = field(default=...) keep_axis: tuple[str, int | ellipsis] | ellipsis = field(default=...) - param_match: Collection[dict[str, Any]] = field(default=()) + params_or: Collection[dict[str, Any]] = field(default=()) T = TypeVar("T", bound=np.generic) diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 5503a16..7069541 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -189,9 +189,7 @@ def gridsearch_results(): def test_find_gridpoints(gridsearch_results, locator): want, full_details = gridsearch_results results = gridsearch.find_gridpoints(locator, full_details) - for result in results: - assert result is want - assert want in results + assert [want] == results def test_grid_locator_match(): From dc97a46dad9a49cf7799780855cae6c29952bb65 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 11 Mar 2024 13:17:23 +0000 Subject: [PATCH 27/46] fix(gridsearch): enable matching callable criteria --- src/gen_experiments/gridsearch/__init__.py | 29 +++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index f6d2f80..c2ff749 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -4,7 +4,16 @@ from logging import getLogger from pprint import 
pformat from types import EllipsisType as ellipsis -from typing import Annotated, Any, Collection, Optional, Sequence, TypeVar, cast +from typing import ( + Annotated, + Any, + Callable, + Collection, + Optional, + Sequence, + TypeVar, + cast, +) import matplotlib.pyplot as plt import numpy as np @@ -83,6 +92,7 @@ def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: def _param_normalize(val: _EqTester) -> _EqTester | str: + """Allow equality testing of mutable objects with useful reprs""" if type(val).__eq__ == object.__eq__: return repr(val) else: @@ -603,22 +613,35 @@ def find_gridpoints( else: ind = find.keep_axis[1] keep_el_sl = slice(ind, ind + 1) + for ser in where["series_data"].values(): ser = ser[keep_axis_sl] for _, amax_arr in ser: amax_want = amax_arr[metric_sl, keep_el_sl].flatten() partial_match.extend(amax_want) + logger.debug( + f"Found {len(partial_match)} gridpoints that match metric-plot_axis criteria" + ) params_or = tuple( - {k: _param_normalize(v) for k, v in params_match.items()} + {k: v if callable(v) else _param_normalize(v) for k, v in params_match.items()} for params_match in find.params_or ) + + def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: + if callable(criteria): + return criteria(candidate) + else: + return _param_normalize(candidate) == criteria + for point in where["plot_data"]: for params_match in params_or: if all( - _param_normalize(point["params"][param]) == value + check_values(value, point["params"][param]) for param, value in params_match.items() ): results.append(point) break + + logger.debug(f"found {len(results)} points that match all GridLocator criteria") return results From 3b954fc306fba2d299afeb84673169a6d674d5dc Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:34:46 +0000 Subject: [PATCH 28/46] feat(gridsearch): Enable locating a set of metrics or keep_axes Including add a test case for this. 
--- src/gen_experiments/gridsearch/__init__.py | 73 +++++++++++++++------- src/gen_experiments/gridsearch/typing.py | 33 +++++++--- tests/test_gridsearch.py | 22 ++++--- 3 files changed, 89 insertions(+), 39 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index c2ff749..6e84fc8 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -12,6 +12,7 @@ Optional, Sequence, TypeVar, + Union, cast, ) @@ -61,9 +62,6 @@ def _amax_to_full_inds( all indexers to full gridsearch that are requested by amax_inds """ - def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: - return tuple(int(el) for el in cast(Iterable, tuple_like)) - if amax_inds is ...: # grab each element from arrays in list of lists of arrays return { np_to_primitive(el) @@ -595,30 +593,33 @@ def find_gridpoints( Returns: A list of the wrapped results, representing points in the gridsearch. """ + results: list[SavedGridPoint] = [] - partial_match: list[tuple[int, ...]] = [] - if find.metric is ...: - metric_sl = slice(None) + partial_match: set[tuple[int, ...]] = set() + if find.metrics is ...: + inds_of_metrics = range(len(where["metrics"])) else: - ind = where["metrics"].index(find.metric) - metric_sl = slice(ind, ind + 1) - if find.keep_axis is ...: - keep_axis_sl = slice(None) - keep_el_sl = slice(None) + inds_of_metrics = tuple( + where["metrics"].index(metric) for metric in find.metrics + ) + ax_sizes = { + plot_ax: len(where["grid_vals"][where["grid_params"].index(plot_ax)]) + for plot_ax in where["plot_params"] + } + if ... 
in find.keep_axes: + keep_axes = _expand_ellipsis_axis(find.keep_axes, ax_sizes) # type: ignore else: - ind = where["plot_params"].index(find.keep_axis[0]) - keep_axis_sl = slice(ind, ind + 1) - if find.keep_axis[1] is ...: - keep_el_sl = slice(None) - else: - ind = find.keep_axis[1] - keep_el_sl = slice(ind, ind + 1) + keep_axes = cast(Collection[tuple[str, tuple[int, ...]]], find.keep_axes) + # No deduplication is done! + keep_axes = tuple( + (where["grid_params"].index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes + ) for ser in where["series_data"].values(): - ser = ser[keep_axis_sl] - for _, amax_arr in ser: - amax_want = amax_arr[metric_sl, keep_el_sl].flatten() - partial_match.extend(amax_want) + for index_of_ax, indexes_in_ax in keep_axes: + amax_arr = ser[index_of_ax][1] + amax_want = amax_arr[np.ix_(inds_of_metrics, indexes_in_ax)].flatten() + partial_match |= {np_to_primitive(el) for el in amax_want} logger.debug( f"Found {len(partial_match)} gridpoints that match metric-plot_axis criteria" ) @@ -634,7 +635,7 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: else: return _param_normalize(candidate) == criteria - for point in where["plot_data"]: + for point in filter(lambda p: p["pind"] in partial_match, where["plot_data"]): for params_match in params_or: if all( check_values(value, point["params"][param]) @@ -645,3 +646,29 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: logger.debug(f"found {len(results)} points that match all GridLocator criteria") return results + + +def _expand_ellipsis_axis( + keep_axis: Union[ + tuple[ellipsis, ellipsis], + tuple[ellipsis, tuple[int, ...]], + tuple[tuple[str, ...], ellipsis], + ], + ax_sizes: dict[str, int], +) -> Collection[tuple[str, tuple[int, ...]]]: + if keep_axis[0] is ... 
and keep_axis[1] is ...: + # form 1 + return tuple((k, tuple(range(v))) for k, v in ax_sizes.items()) + elif isinstance(keep_axis[1], tuple): + # form 2 + return tuple((k, keep_axis[1]) for k in ax_sizes.keys()) + elif isinstance(keep_axis[0], tuple): + # form 3 + return tuple((k, tuple(range(ax_sizes[k]))) for k in keep_axis[0]) + else: + raise TypeError("Keep_axis does not have an ellipsis or is not a 2-tuple") + + +def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: + """Turn a void that represents a tuple of ints into a tuple of ints""" + return tuple(int(el) for el in cast(Iterable, tuple_like)) diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index 91ea1fd..7e890f7 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -10,6 +10,7 @@ Sequence, TypedDict, TypeVar, + Union, ) import numpy as np @@ -21,6 +22,14 @@ SkinnySpecs = tuple[tuple[str, ...], tuple[OtherSliceDef, ...]] +KeepAxisSpec = Union[ + tuple[ellipsis, ellipsis], # all axes, all indices + tuple[ellipsis, tuple[int, ...]], # all axes, specific indices + tuple[tuple[str, ...], ellipsis], # specific axes, all indices + Collection[tuple[str, tuple[int, ...]]], # specific axes, specific indices +] + + @dataclass(frozen=True) class GridLocator: """A specification of which points in a gridsearch to match. @@ -30,13 +39,23 @@ class GridLocator: Kalman series that had the best mean squared error as noise was varied. - Logical AND is applied across the metric, keep_axis, AND param_match specifications. + Logical AND is applied across the metric, keep_axis, AND param_match + specifications. Args: - metric: The metric in which to find results. An ellipsis means "any metrics" - keep_axis: The grid-varied parameter in which to find results, or a tuple of - that axis and position along that axis. To search a particular value of - that parameter, use the param_match kwarg. 
An ellipsis means "any axis" + metric: The metric in which to find results. An ellipsis means "any + metrics" + keep_axis: The grid-varied parameter in which to find results and which + index of values for that parameter. To search a particular value of + that parameter, use the param_match kwarg.It can be specified in + several ways: + (a) a tuple of two ellipses, representing all axes, all indices + (b) a tuple of an ellipsis, representing all axes, and a tuple of + ints for specific indices + (c) a tuple of a tuple of strings for specific axes, and an ellipsis + for all indices + (d) a collection of tuples of a string (specific axis) and tuple of + ints (specific indices) param_match: A collection of dictionaries to match parameter values represented by points in the gridsearch. Dictionary equality is checked for every non-callable value; for callable values, it is applied to the grid @@ -45,8 +64,8 @@ class GridLocator: across the collection. """ - metric: str | ellipsis = field(default=...) - keep_axis: tuple[str, int | ellipsis] | ellipsis = field(default=...) + metrics: Collection[str] | ellipsis = field(default=...) 
+ keep_axes: KeepAxisSpec = field(default=(..., ...)) params_or: Collection[dict[str, Any]] = field(default=()) diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 7069541..a3c921f 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -153,9 +153,10 @@ def gridsearch_results(): "pind": (0,), "data": {}, } + tup_dtype = np.dtype([("f0", "i")]) max_amax: SeriesData = [ - (np.ones((2, 2)), np.array([[(1,), (1,)], [(0,), (0,)]])), - (np.ones((2, 2)), np.array([[(0,), (0,)], [(0,), (0,)]])), + (np.ones((2, 2)), np.array([[(1,), (0)], [(0,), (0,)]], dtype=tup_dtype)), + (np.ones((2, 2)), np.array([[(0,), (0,)], [(0,), (0,)]], dtype=tup_dtype)), ] full_details: GridsearchResultDetails = { "system": "sho", @@ -164,7 +165,7 @@ def gridsearch_results(): "metrics": ("mse", "mae"), "grid_params": ["sim_params.t_end", "sim_params.noise"], "plot_params": ["sim_params.t_end", "sim_params.noise"], - "grid_vals": [[1, 2, 3], [4, 5, 6]], + "grid_vals": [[1, 2], [5, 6]], "main": 1, } return want, full_details @@ -174,17 +175,20 @@ def gridsearch_results(): "locator", ( gridsearch.GridLocator( - "mse", ("sim_params.t_end", ...), [{"diff_params.alpha": 0.1}] + ("mse",), (("sim_params.t_end",), ...), [{"diff_params.alpha": 0.1}] ), gridsearch.GridLocator( - "mse", ("sim_params.t_end", ...), [{"opt_params": ps.STLSQ()}] + ("mse",), (("sim_params.t_end",), ...), [{"opt_params": ps.STLSQ()}] ), - gridsearch.GridLocator(..., ..., [{"diff_params.alpha": lambda x: x < 0.2}]), gridsearch.GridLocator( - ..., ..., [{"diff_params.alpha": 0.1}, {"diff_params.alpha": 0.3}] + ..., (..., ...), [{"diff_params.alpha": lambda x: x < 0.2}] + ), + gridsearch.GridLocator(("mse",), {("sim_params.t_end", (0,))}, [{}]), + gridsearch.GridLocator( + ..., (..., ...), [{"diff_params.alpha": 0.1}, {"diff_params.alpha": 0.3}] ), ), - ids=("exact", "object", "callable", "or"), + ids=("exact", "object", "callable", "by_axis", "or"), ) def test_find_gridpoints(gridsearch_results, 
locator): want, full_details = gridsearch_results @@ -196,7 +200,7 @@ def test_grid_locator_match(): m_params = {"sim_params.t_end": 10, "foo": 1} m_ind = (0, 1) # Effectively testing the clause: (x OR y OR ...) AND (a OR b OR ...) - # Note: OR() with no args is falsy + # Note: OR() with no args is falsy, AND() with no args is thruthy # also note first index is stripped ind_spec good_specs = [ (({"sim_params.t_end": 10},), ((1, 0, 1),)), From adf2d089d562e60538eb23a10eb5327421a64b32 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 12 Mar 2024 15:00:34 +0000 Subject: [PATCH 29/46] tst(gridsearch): Add integration test --- tests/test_gridsearch.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index a3c921f..7f530d0 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -221,3 +221,16 @@ def test_grid_locator_match(): ] for param_spec, ind_spec in bad_specs: assert not gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) + + +def test_gridsearch_mock(): + results = gridsearch.run( + 1, + "none", + grid_params=["foo"], + grid_vals=[[0, 1]], + grid_decisions=["plot"], + other_params={"bar": False}, + metrics=("mse", "mae"), + ) + assert len(results["plot_data"]) == 0 From d17275e5155889db2b07b9944406293a245c8ae1 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 12 Mar 2024 15:03:56 +0000 Subject: [PATCH 30/46] feat(gridsearch): Use find_gridpoints() in gridsearch.run() Also: make skinny_specs optional --- src/gen_experiments/config.py | 93 ++++++----------- src/gen_experiments/gridsearch/__init__.py | 111 ++++++++++----------- src/gen_experiments/plotting.py | 10 +- tests/test_gridsearch.py | 4 +- 4 files changed, 92 insertions(+), 126 deletions(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 9847370..b90ece4 100644 --- 
a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -5,7 +5,12 @@ import pysindy as ps from gen_experiments.data import _signal_avg_power -from gen_experiments.gridsearch.typing import SeriesDef, SeriesList, SkinnySpecs +from gen_experiments.gridsearch.typing import ( + GridLocator, + SeriesDef, + SeriesList, + SkinnySpecs, +) from gen_experiments.plotting import _PlotPrefs from gen_experiments.typing import NestedDict from gen_experiments.utils import FullSINDyTrialData @@ -44,77 +49,43 @@ def addn(x): plot_prefs = { - "test": _PlotPrefs(True, False, ({"sim_params.t_end": 10},)), + "test": _PlotPrefs(), "test-absrel": _PlotPrefs( - True, _convert_abs_rel_noise, ({"sim_params.noise_abs": 1},) + True, _convert_abs_rel_noise, GridLocator(..., {("sim_params.noise_abs", (1,))}) ), "test-absrel2": _PlotPrefs( True, _convert_abs_rel_noise, - ( - {"sim_params.noise_abs": 0.1}, - {"sim_params.noise_abs": 0.5}, - {"sim_params.noise_abs": 1}, - {"sim_params.noise_abs": 2}, - {"sim_params.noise_abs": 4}, - {"sim_params.noise_abs": 8}, + GridLocator( + ..., + (..., ...), + ( + {"sim_params.noise_abs": 0.1}, + {"sim_params.noise_abs": 0.5}, + {"sim_params.noise_abs": 1}, + {"sim_params.noise_abs": 2}, + {"sim_params.noise_abs": 4}, + {"sim_params.noise_abs": 8}, + ), ), ), - "test-absrel3": _PlotPrefs( + "absrel-newloc": _PlotPrefs( True, _convert_abs_rel_noise, - ( - { - "sim_params.noise_abs": 1, - "diff_params.smoother_kws.window_length": 15, - }, - {"sim_params.noise_abs": 1, "diff_params.meas_var": 1}, - {"sim_params.noise_abs": 1, "diff_params.alpha": 1e-2}, + GridLocator( + ["coeff_mse", "coeff_f1"], + (..., (2, 3, 4)), + ( + {"diff_params.kind": "kalman", "diff_params.alpha": None}, + { + "diff_params.kind": "kalman", + "diff_params.alpha": lambda a: isinstance(a, int), + }, + {"diff_params.kind": "trend_filtered"}, + {"diff_params.diffcls": "SmoothedFiniteDifference"}, + ), ), ), - "test-absrel4": _PlotPrefs( - True, - _convert_abs_rel_noise, - ( - { 
- "sim_params.noise_abs": 1, - "diff_params.smoother_kws.window_length": 15, - }, - {"sim_params.noise_abs": 1, "diff_params.meas_var": 1}, - {"sim_params.noise_abs": 1, "diff_params.alpha": 1e0}, - { - "sim_params.noise_abs": 2, - "diff_params.smoother_kws.window_length": 15, - }, - {"sim_params.noise_abs": 2, "diff_params.meas_var": 4}, - {"sim_params.noise_abs": 2, "diff_params.alpha": 1e-1}, - ), - ), - "test-absrel5": _PlotPrefs( - True, - _convert_abs_rel_noise, - ( - { - "sim_params.noise_abs": 1, - "diff_params.diffcls": "SmoothedFiniteDifference", - }, - {"sim_params.noise_abs": 1, "diff_params.kind": "kalman"}, - {"sim_params.noise_abs": 1, "diff_params.kind": "trend_filtered"}, - { - "sim_params.noise_abs": 2, - "diff_params.diffcls": "SmoothedFiniteDifference", - }, - {"sim_params.noise_abs": 2, "diff_params.kind": "kalman"}, - {"sim_params.noise_abs": 2, "diff_params.kind": "trend_filtered"}, - { - "sim_params.noise_abs": 4, - "diff_params.diffcls": "SmoothedFiniteDifference", - }, - {"sim_params.noise_abs": 4, "diff_params.kind": "kalman"}, - {"sim_params.noise_abs": 4, "diff_params.kind": "trend_filtered"}, - ), - {(0, 2), (3, 2), (0, 3), (3, 3), (0, 4), (3, 4)}, - ), } sim_params = { "test": ND({"n_trajectories": 2}), diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 6e84fc8..941d17d 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -15,6 +15,7 @@ Union, cast, ) +from warnings import warn import matplotlib.pyplot as plt import numpy as np @@ -118,6 +119,7 @@ def _grid_locator_match( param_spec: the criteria for matching exp_params ind_spec: the criteria for matching exp_ind """ + warn("Use find_gridpoints() instead", DeprecationWarning) found_match = False for params_or in param_spec: params_or = {k: _param_normalize(v) for k, v in params_or.items()} @@ -147,10 +149,10 @@ def run( grid_vals: list[Sequence], grid_decisions: list[str], 
other_params: dict, - skinny_specs: SkinnySpecs, + skinny_specs: Optional[SkinnySpecs] = None, series_params: Optional[SeriesList] = None, metrics: tuple[str, ...] = (), - plot_prefs: _PlotPrefs = _PlotPrefs(True, False, ()), + plot_prefs: _PlotPrefs = _PlotPrefs(), ) -> GridsearchResultDetails: """Run a grid-search wrapper of an experiment. @@ -228,31 +230,41 @@ def run( ) series_searches.append((grid_optima, grid_ind)) + main_metric_ind = metrics.index("main") if "main" in metrics else 0 + results: GridsearchResultDetails = { + "system": group, + "plot_data": [], + "series_data": { + name: data + for data, name in zip( + [list(zip(metrics, argopts)) for metrics, argopts in series_searches], + [ser.name for ser in series_params.series_list], + ) + }, + "metrics": metrics, + "grid_params": grid_params, + "plot_params": [ + param + for decide, param in zip(grid_decisions, grid_params) + if decide == "plot" + ], + "grid_vals": grid_vals, + "main": max( + grid[main_metric_ind].max() + for metrics, _ in series_searches + for grid in metrics + ), + } if plot_prefs: - full_m_inds = _amax_to_full_inds( - plot_prefs.grid_ind_match, [s[1] for s in series_searches] - ) - for int_data in intermediate_data: - logger.debug( - f"Checking whether to save/plot :\n{pformat(int_data['params'])}\n" - f"\tat location {pformat(int_data['pind'])}\n" - f"\tagainst spec: {pformat(plot_prefs.grid_params_match)}\n" - f"\twith allowed locations {pformat(full_m_inds)}" + plot_data = find_gridpoints(plot_prefs.plot_match, intermediate_data, results) + results["plot_data"] = plot_data + for gridpoint in plot_data: + grid_data = gridpoint["data"] + logger.info(f"Plotting: {gridpoint['params']}") + grid_data |= simulate_test_data( + grid_data["model"], grid_data["dt"], grid_data["x_test"] ) - if _grid_locator_match( - int_data["params"], - int_data["pind"], - plot_prefs.grid_params_match, - full_m_inds, - ) and int_data["params"] not in [saved["params"] for saved in plot_data]: - grid_data = 
int_data["data"] - print("Results for params: ", int_data["params"], flush=True) - grid_data |= simulate_test_data( - grid_data["model"], grid_data["dt"], grid_data["x_test"] - ) - logger.info("Found match, simulating and plotting") - plot_ode_panel(grid_data) - plot_data.append(int_data) + plot_ode_panel(grid_data) # type: ignore if plot_prefs.rel_noise: grid_vals, grid_params = plot_prefs.rel_noise( grid_vals, grid_params, grid_data @@ -284,27 +296,7 @@ def run( fig.suptitle(title) fig.tight_layout() - main_metric_ind = metrics.index("main") if "main" in metrics else 0 - return { - "system": group, - "plot_data": plot_data, - "series_data": { - name: data - for data, name in zip( - [list(zip(metrics, argopts)) for metrics, argopts in series_searches], - [ser.name for ser in series_params.series_list], - ) - }, - "metrics": metrics, - "grid_params": grid_params, - "plot_params": [decide for decide in grid_decisions if decide == "plot"], - "grid_vals": grid_vals, - "main": max( - grid[main_metric_ind].max() - for metrics, _ in series_searches - for grid in metrics - ), - } + return results def plot( @@ -582,29 +574,31 @@ def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> def find_gridpoints( - find: GridLocator, where: GridsearchResultDetails + find: GridLocator, where: list[SavedGridPoint], context: GridsearchResultDetails ) -> list[SavedGridPoint]: """Find results wrapped by gridsearch that match criteria Args: find: the criteria - where: The overall results of the gridsearch + where: The list of saved gridpoints to search + context: The overall data for the gridsearch, describing metrics, grid + setup, and gridsearch results Returns: - A list of the wrapped results, representing points in the gridsearch. + A list of the matching points in the gridsearch. 
""" results: list[SavedGridPoint] = [] partial_match: set[tuple[int, ...]] = set() if find.metrics is ...: - inds_of_metrics = range(len(where["metrics"])) + inds_of_metrics = range(len(context["metrics"])) else: inds_of_metrics = tuple( - where["metrics"].index(metric) for metric in find.metrics + context["metrics"].index(metric) for metric in find.metrics ) ax_sizes = { - plot_ax: len(where["grid_vals"][where["grid_params"].index(plot_ax)]) - for plot_ax in where["plot_params"] + plot_ax: len(context["grid_vals"][context["grid_params"].index(plot_ax)]) + for plot_ax in context["plot_params"] } if ... in find.keep_axes: keep_axes = _expand_ellipsis_axis(find.keep_axes, ax_sizes) # type: ignore @@ -612,15 +606,15 @@ def find_gridpoints( keep_axes = cast(Collection[tuple[str, tuple[int, ...]]], find.keep_axes) # No deduplication is done! keep_axes = tuple( - (where["grid_params"].index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes + (context["grid_params"].index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes ) - for ser in where["series_data"].values(): + for ser in context["series_data"].values(): for index_of_ax, indexes_in_ax in keep_axes: amax_arr = ser[index_of_ax][1] amax_want = amax_arr[np.ix_(inds_of_metrics, indexes_in_ax)].flatten() partial_match |= {np_to_primitive(el) for el in amax_want} - logger.debug( + logger.info( f"Found {len(partial_match)} gridpoints that match metric-plot_axis criteria" ) @@ -635,7 +629,8 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: else: return _param_normalize(candidate) == criteria - for point in filter(lambda p: p["pind"] in partial_match, where["plot_data"]): + for point in filter(lambda p: p["pind"] in partial_match, where): + logger.debug(f"Checking whether {point['pind']} matches param query") for params_match in params_or: if all( check_values(value, point["params"][param]) @@ -644,7 +639,7 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: 
results.append(point) break - logger.debug(f"found {len(results)} points that match all GridLocator criteria") + logger.info(f"found {len(results)} points that match all GridLocator criteria") return results diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 153724a..4a55e05 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field -from types import EllipsisType as ellipsis -from typing import Annotated, Any, Callable, Collection, Literal, Sequence +from typing import Annotated, Callable, Literal, Sequence import matplotlib.pyplot as plt import numpy as np @@ -8,6 +7,8 @@ import seaborn as sns from matplotlib.axes import Axes +from .gridsearch.typing import GridLocator + PAL = sns.color_palette("Set1") PLOT_KWS = {"alpha": 0.7, "linewidth": 3} @@ -29,10 +30,7 @@ class _PlotPrefs: plot: bool = True rel_noise: Literal[False] | Callable = False - grid_params_match: Collection[dict[str, Any]] = field(default_factory=lambda: ()) - grid_ind_match: Collection[tuple[int | slice, int]] | ellipsis = field( - default_factory=lambda: ... 
- ) + plot_match: GridLocator = field(default_factory=lambda: GridLocator()) def __bool__(self): return self.plot diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 7f530d0..6d45ba9 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -192,7 +192,9 @@ def gridsearch_results(): ) def test_find_gridpoints(gridsearch_results, locator): want, full_details = gridsearch_results - results = gridsearch.find_gridpoints(locator, full_details) + results = gridsearch.find_gridpoints( + locator, full_details["plot_data"], full_details + ) assert [want] == results From 4b5b1e87acbde2e1f9de8ac0ccd299b70fc688f9 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 20 Mar 2024 21:45:31 +0000 Subject: [PATCH 31/46] fix: Only care about param index, not metric index, in gridsearch. There has been some ambiguity about what index is stored in gridsearch results. Previously, GridsearchResult contained the indexes of the optimal result in the full series gridsearch array, which included an index of the metric as well as the indexes to uniquely identify parameters at the gridpoint. However, the SavedGridPoint objects obviously only knew their parameter indexes, rather than all carrying around indexes for every metric (which would be identical across every point). Now, early in gridsearch, the metric index is lost, meaning that every GridsearchResult will store indexing tuples for the optimal results that align with parameter indexes. 
Also, rename np_to_primitive() --- src/gen_experiments/gridsearch/__init__.py | 31 +++++++++------------- src/gen_experiments/gridsearch/typing.py | 2 +- tests/test_gridsearch.py | 6 ++--- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 941d17d..089b7e8 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -65,7 +65,7 @@ def _amax_to_full_inds( if amax_inds is ...: # grab each element from arrays in list of lists of arrays return { - np_to_primitive(el) + void_to_tuple(el) for ar_list in amax_arrays for arr in ar_list for el in arr.flatten() @@ -75,15 +75,15 @@ def _amax_to_full_inds( for ind in amax_inds: if ind is ...: # grab each element from arrays in list of lists of arrays all_inds |= { - np_to_primitive(el) + void_to_tuple(el) for ar_list in amax_arrays for arr in ar_list for el in arr.flatten() } elif isinstance(ind[0], int): - all_inds |= {np_to_primitive(cast(np.void, plot_axis_results[ind]))} + all_inds |= {void_to_tuple(cast(np.void, plot_axis_results[ind]))} else: # ind[0] is slice(None) - all_inds |= {np_to_primitive(el) for el in plot_axis_results[ind]} + all_inds |= {void_to_tuple(el) for el in plot_axis_results[ind]} return all_inds @@ -383,10 +383,11 @@ def _marginalize_grid_views( grid_decisions: Iterable[str], results: Annotated[NDArray[T], "shape (n_metrics, *n_gridsearch_values)"], max_or_min: Sequence[str], -) -> tuple[list[GridsearchResult[T]], list[GridsearchResult]]: - """Marginalize unnecessary dimensions by taking max across axes. +) -> tuple[list[GridsearchResult[T]], list[GridsearchResult[np.void]]]: + """Marginalize unnecessary dimensions by taking max or min across axes. + + Ignores NaN values and strips the metric index from the argoptima. - Ignores NaN values Args: grid_decisions: list of how to treat each non-metric gridsearch axis. 
An array of metrics for each "plot" grid decision @@ -396,9 +397,8 @@ def _marginalize_grid_views( max_or_min: either "max" or "min" for each row of results Returns: a list of the metric optima for each plottable grid decision, and - a list of the flattened argoptima. + a list of the flattened argoptima, with metric removed """ - arg_dtype = np.dtype(",".join(results.ndim * "i")) plot_param_inds = [ind for ind, val in enumerate(grid_decisions) if val == "plot"] grid_searches = [] args_maxes = [] @@ -409,13 +409,8 @@ def _marginalize_grid_views( [opt(result, axis=reduce_axes) for opt, result in zip(optfuns, results)] ) sub_arrs = [] - for m_ind, (result, opt) in enumerate(zip(results, max_or_min)): - - def _metric_pad(tp: tuple[int, ...]) -> np.void: - return np.void((m_ind, *tp), dtype=arg_dtype) - - pad_m_ind = np.vectorize(_metric_pad) - arg_max = pad_m_ind(_argopt(result, reduce_axes, opt)) + for result, opt in zip(results, max_or_min): + arg_max = _argopt(result, reduce_axes, opt) sub_arrs.append(arg_max) args_max = np.stack(sub_arrs) @@ -613,7 +608,7 @@ def find_gridpoints( for index_of_ax, indexes_in_ax in keep_axes: amax_arr = ser[index_of_ax][1] amax_want = amax_arr[np.ix_(inds_of_metrics, indexes_in_ax)].flatten() - partial_match |= {np_to_primitive(el) for el in amax_want} + partial_match |= {void_to_tuple(el) for el in amax_want} logger.info( f"Found {len(partial_match)} gridpoints that match metric-plot_axis criteria" ) @@ -664,6 +659,6 @@ def _expand_ellipsis_axis( raise TypeError("Keep_axis does not have an ellipsis or is not a 2-tuple") -def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: +def void_to_tuple(tuple_like: np.void) -> tuple[int, ...]: """Turn a void that represents a tuple of ints into a tuple of ints""" return tuple(int(el) for el in cast(Iterable, tuple_like)) diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index 7e890f7..4b766d2 100644 --- 
a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -78,7 +78,7 @@ class GridLocator: Annotated[GridsearchResult[np.void], "arg_opts"], ] ], - "len=n_grid_axes", + "len=n_plot_axes", ] ExpResult = dict[str, Any] diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 6d45ba9..992ec16 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -78,10 +78,10 @@ def test_marginalize_grid_views(): for result, expected in zip(res_val, expected_val): np.testing.assert_array_equal(result, expected) - ts = "i,i,i,i" + ts = "i,i,i" expected_ind = [ - np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 0, 0, 0), (1, 1, 1, 0)]], ts), - np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 1, 0), (1, 0, 0, 1)]], ts), + np.array([[(0, 0, 0), (1, 1, 1)], [(0, 0, 0), (1, 1, 0)]], ts), + np.array([[(0, 0, 0), (1, 1, 1)], [(1, 1, 0), (0, 0, 1)]], ts), ] for result, expected in zip(res_ind, expected_ind): np.testing.assert_array_equal(result, expected) From dc29e6ef4d2f02d2c4e210075224b58ac86cc38e Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 20 Mar 2024 22:01:36 +0000 Subject: [PATCH 32/46] fix: Handle grid locators with missing keys --- src/gen_experiments/gridsearch/__init__.py | 2 +- tests/test_gridsearch.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 089b7e8..9477fcd 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -628,7 +628,7 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: logger.debug(f"Checking whether {point['pind']} matches param query") for params_match in params_or: if all( - check_values(value, point["params"][param]) + param in point["params"] and check_values(value, point["params"][param]) for param, value in params_match.items() ): 
results.append(point) diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 992ec16..ce9c5bc 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -187,8 +187,9 @@ def gridsearch_results(): gridsearch.GridLocator( ..., (..., ...), [{"diff_params.alpha": 0.1}, {"diff_params.alpha": 0.3}] ), + gridsearch.GridLocator(params_or=[{"diff_params.alpha": 0.1}, {"foo": 0}]), ), - ids=("exact", "object", "callable", "by_axis", "or"), + ids=("exact", "object", "callable", "by_axis", "or", "missingkey"), ) def test_find_gridpoints(gridsearch_results, locator): want, full_details = gridsearch_results From 5c2c93df8fc8d94d1b74548613292ccf267a467d Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:55:13 +0000 Subject: [PATCH 33/46] types: cleaner typing --- src/gen_experiments/__init__.py | 3 ++- src/gen_experiments/config.py | 7 +++++-- src/gen_experiments/data.py | 2 +- src/gen_experiments/gridsearch/__init__.py | 1 + src/gen_experiments/gridsearch/typing.py | 6 ++++-- src/gen_experiments/odes.py | 2 +- src/gen_experiments/plotting.py | 4 ++-- tests/test_gridsearch.py | 4 ++-- 8 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index cd4a3dc..b9708aa 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -4,6 +4,7 @@ from typing import Any import numpy as np +from mitosis import Experiment from numpy.typing import NDArray from pysindy import BaseDifferentiation, FiniteDifference, SINDy @@ -67,7 +68,7 @@ def run( return metrics -experiments = { +experiments: dict[str, tuple[Experiment, str | None]] = { "sho": (odes, "sho"), "lorenz": (odes, "lorenz"), "lorenz_2d": (odes, "lorenz_2d"), diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index b90ece4..0a50f1c 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -3,6 +3,7 
@@ import numpy as np import pysindy as ps +from numpy.typing import NDArray from gen_experiments.data import _signal_avg_power from gen_experiments.gridsearch.typing import ( @@ -24,8 +25,10 @@ def ND(d: dict[T, U]) -> NestedDict[T, U]: def _convert_abs_rel_noise( - grid_vals: list, grid_params: list, recent_results: FullSINDyTrialData -): + grid_vals: list[NDArray[np.floating]], + grid_params: list[str], + recent_results: FullSINDyTrialData, +) -> tuple[list[NDArray[np.floating]], list[str]]: """Convert abs_noise grid_vals to rel_noise""" signal = np.stack(recent_results["x_true"], axis=-1) signal_power = _signal_avg_power(signal) diff --git a/src/gen_experiments/data.py b/src/gen_experiments/data.py index 0fc9760..6c95c73 100644 --- a/src/gen_experiments/data.py +++ b/src/gen_experiments/data.py @@ -26,7 +26,7 @@ def gen_data( nonnegative: bool = False, dt: float = 0.01, t_end: float = 10, -) -> tuple[float, Float1D, Float2D, Float2D, Float2D, Float2D]: +) -> tuple[float, Float1D, list[Float2D], list[Float2D], list[Float2D], list[Float2D]]: """Generate random training and test data Note that test data has no noise. 
diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 9477fcd..762beec 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -231,6 +231,7 @@ def run( series_searches.append((grid_optima, grid_ind)) main_metric_ind = metrics.index("main") if "main" in metrics else 0 + results: GridsearchResultDetails = { "system": group, "plot_data": [], diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index 4b766d2..4c88da7 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -6,6 +6,7 @@ Any, Callable, Collection, + Generic, Optional, Sequence, TypedDict, @@ -82,9 +83,10 @@ class GridLocator: ] ExpResult = dict[str, Any] +ExpResultVar = TypeVar("ExpResultVar", bound=ExpResult) -class SavedGridPoint(TypedDict): +class SavedGridPoint(TypedDict, Generic[ExpResultVar]): """The results at a point in the gridsearch. Args: @@ -95,7 +97,7 @@ class SavedGridPoint(TypedDict): params: dict pind: tuple[int, ...] 
- data: ExpResult + data: ExpResultVar class GridsearchResultDetails(TypedDict): diff --git a/src/gen_experiments/odes.py b/src/gen_experiments/odes.py index e45277a..2ed043f 100644 --- a/src/gen_experiments/odes.py +++ b/src/gen_experiments/odes.py @@ -154,7 +154,7 @@ def forcing(t, x): def run( - seed: float, + seed: int, group: str, sim_params: dict, diff_params: dict, diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 4a55e05..5c8ac8d 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Annotated, Callable, Literal, Sequence +from typing import Annotated, Any, Callable, Literal, Sequence import matplotlib.pyplot as plt import numpy as np @@ -29,7 +29,7 @@ class _PlotPrefs: """ plot: bool = True - rel_noise: Literal[False] | Callable = False + rel_noise: Literal[False] | Callable[..., tuple[list[Any], list[str]]] = False plot_match: GridLocator = field(default_factory=lambda: GridLocator()) def __bool__(self): diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index ce9c5bc..4dc4770 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -163,9 +163,9 @@ def gridsearch_results(): "plot_data": [want, dont_want], "series_data": {"foo": max_amax}, "metrics": ("mse", "mae"), - "grid_params": ["sim_params.t_end", "sim_params.noise"], + "grid_params": ["sim_params.t_end", "bar", "sim_params.noise"], "plot_params": ["sim_params.t_end", "sim_params.noise"], - "grid_vals": [[1, 2], [5, 6]], + "grid_vals": [[1, 2], [7, 8], [5, 6]], "main": 1, } return want, full_details From 72d0512deac03dfcc9d6c4c6395c098555510089 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 27 Mar 2024 11:27:03 +0000 Subject: [PATCH 34/46] fix(gridsearch): Make plot_params more explicit --- src/gen_experiments/gridsearch/__init__.py | 46 +++++++++++++--------- 1 file changed, 28 
insertions(+), 18 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 762beec..d129cff 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -35,6 +35,7 @@ GridLocator, GridsearchResult, GridsearchResultDetails, + KeepAxisSpec, OtherSliceDef, SavedGridPoint, SeriesDef, @@ -244,11 +245,7 @@ def run( }, "metrics": metrics, "grid_params": grid_params, - "plot_params": [ - param - for decide, param in zip(grid_decisions, grid_params) - if decide == "plot" - ], + "plot_params": [], "grid_vals": grid_vals, "main": max( grid[main_metric_ind].max() @@ -268,8 +265,16 @@ def run( plot_ode_panel(grid_data) # type: ignore if plot_prefs.rel_noise: grid_vals, grid_params = plot_prefs.rel_noise( - grid_vals, grid_params, grid_data + grid_vals, grid_params, intermediate_data ) + results["grid_vals"] = grid_vals + plot_params = [ + param + for decide, param in zip(grid_decisions, grid_params) + if decide == "plot" + ] + results["plot_params"] = plot_params + fig, subplots = plt.subplots( n_metrics, n_plotparams, @@ -278,15 +283,15 @@ def run( squeeze=False, figsize=(n_plotparams * 3, 0.5 + n_metrics * 2.25), ) - for series_data, series_name in zip( + for series_search, series_name in zip( series_searches, (ser.name for ser in series_params.series_list) ): plot( subplots, metrics, - grid_params, + plot_params, grid_vals, - series_data[0], + series_search[0], series_name, legends, ) @@ -303,7 +308,7 @@ def run( def plot( subplots: NDArray[Annotated[np.void, "Axes"]], metrics: Sequence[str], - grid_params: Sequence[str], + plot_params: Sequence[str], grid_vals: Sequence[Sequence[float] | np.ndarray], grid_searches: Sequence[GridsearchResult], name: str, @@ -313,7 +318,7 @@ def plot( raise ValueError("Nothing to plot") for m_ind_row, m_name in enumerate(metrics): for col, (param_name, x_ticks, param_search) in enumerate( - zip(grid_params, grid_vals, grid_searches) 
+ zip(plot_params, grid_vals, grid_searches) ): ax = cast(Axes, subplots[m_ind_row, col]) ax.plot(x_ticks, param_search[m_ind_row], label=name) @@ -586,6 +591,7 @@ def find_gridpoints( results: list[SavedGridPoint] = [] partial_match: set[tuple[int, ...]] = set() + inds_of_metrics: Sequence[int] if find.metrics is ...: inds_of_metrics = range(len(context["metrics"])) else: @@ -596,14 +602,8 @@ def find_gridpoints( plot_ax: len(context["grid_vals"][context["grid_params"].index(plot_ax)]) for plot_ax in context["plot_params"] } - if ... in find.keep_axes: - keep_axes = _expand_ellipsis_axis(find.keep_axes, ax_sizes) # type: ignore - else: - keep_axes = cast(Collection[tuple[str, tuple[int, ...]]], find.keep_axes) # No deduplication is done! - keep_axes = tuple( - (context["grid_params"].index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes - ) + keep_axes = _normalize_keep_axes(find.keep_axes, ax_sizes, context["plot_params"]) for ser in context["series_data"].values(): for index_of_ax, indexes_in_ax in keep_axes: @@ -639,6 +639,16 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: return results +def _normalize_keep_axes( + keep_axes: KeepAxisSpec, ax_sizes: dict[str, int], plot_params: list[str] +) -> tuple[tuple[int, tuple[int, ...]], ...]: + if ... 
in keep_axes: + keep_axes = _expand_ellipsis_axis(keep_axes, ax_sizes) # type: ignore + else: + keep_axes = cast(Collection[tuple[str, tuple[int, ...]]], keep_axes) + return tuple((plot_params.index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes) + + def _expand_ellipsis_axis( keep_axis: Union[ tuple[ellipsis, ellipsis], From 17b37d980ad072347d3bf0bbc71ff8fcaf5087cf Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 27 Mar 2024 15:58:44 +0000 Subject: [PATCH 35/46] fix (gridsearch): Separate grid params/vals, scan_grid, and plot_grid This change disambiguates the naming and values of axes between running the gridsearch, deciding which axes should be optimized vs scanned, and which should be recalculated in plotting. This is to allow better relative noise conversion, eventually: _PlotPrefs.rel_noise has been disallowed, as it does not behave correctly. --- src/gen_experiments/config.py | 20 ++++++------ src/gen_experiments/gridsearch/__init__.py | 36 +++++++++------------- src/gen_experiments/gridsearch/typing.py | 5 +-- src/gen_experiments/plotting.py | 2 +- tests/test_gridsearch.py | 3 +- 5 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 0a50f1c..db8c0b1 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -1,5 +1,5 @@ -from collections.abc import Iterable -from typing import TypeVar +from collections.abc import Iterable, Sequence +from typing import TypeVar, cast import numpy as np import pysindy as ps @@ -16,7 +16,7 @@ from gen_experiments.typing import NestedDict from gen_experiments.utils import FullSINDyTrialData -T = TypeVar("T") +T = TypeVar("T", bound=str) U = TypeVar("U") @@ -25,17 +25,17 @@ def ND(d: dict[T, U]) -> NestedDict[T, U]: def _convert_abs_rel_noise( - grid_vals: list[NDArray[np.floating]], - grid_params: list[str], + scan_grid: dict[str, NDArray[np.floating]], recent_results: 
FullSINDyTrialData, -) -> tuple[list[NDArray[np.floating]], list[str]]: +) -> dict[str, Sequence[np.floating]]: """Convert abs_noise grid_vals to rel_noise""" signal = np.stack(recent_results["x_true"], axis=-1) signal_power = _signal_avg_power(signal) - ind = grid_params.index("sim_params.noise_abs") - grid_vals[ind] = grid_vals[ind] / signal_power - grid_params[ind] = "sim_params.noise_rel" - return grid_vals, grid_params + plot_grid = scan_grid.copy() + new_vals = plot_grid["sim_params.noise_abs"] / signal_power + plot_grid["sim_params.noise_rel"] = new_vals + plot_grid.pop("sim_params.noise_abs") + return cast(dict[str, Sequence[np.floating]], plot_grid) # To allow pickling diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index d129cff..6ea0902 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -232,7 +232,9 @@ def run( series_searches.append((grid_optima, grid_ind)) main_metric_ind = metrics.index("main") if "main" in metrics else 0 - + scan_grid = { + p: v for p, d, v in zip(grid_params, grid_decisions, grid_vals) if d == "plot" + } results: GridsearchResultDetails = { "system": group, "plot_data": [], @@ -245,8 +247,9 @@ def run( }, "metrics": metrics, "grid_params": grid_params, - "plot_params": [], "grid_vals": grid_vals, + "scan_grid": scan_grid, + "plot_grid": {}, "main": max( grid[main_metric_ind].max() for metrics, _ in series_searches @@ -264,16 +267,9 @@ def run( ) plot_ode_panel(grid_data) # type: ignore if plot_prefs.rel_noise: - grid_vals, grid_params = plot_prefs.rel_noise( - grid_vals, grid_params, intermediate_data - ) - results["grid_vals"] = grid_vals - plot_params = [ - param - for decide, param in zip(grid_decisions, grid_params) - if decide == "plot" - ] - results["plot_params"] = plot_params + raise ValueError("_PlotPrefs.rel_noise is not correctly implemented.") + else: + results["plot_grid"] = scan_grid fig, subplots = 
plt.subplots( n_metrics, @@ -289,8 +285,8 @@ def run( plot( subplots, metrics, - plot_params, - grid_vals, + cast(Sequence[str], results["plot_grid"].keys()), + cast(Sequence[Sequence], results["plot_grid"].values()), series_search[0], series_name, legends, @@ -598,12 +594,8 @@ def find_gridpoints( inds_of_metrics = tuple( context["metrics"].index(metric) for metric in find.metrics ) - ax_sizes = { - plot_ax: len(context["grid_vals"][context["grid_params"].index(plot_ax)]) - for plot_ax in context["plot_params"] - } # No deduplication is done! - keep_axes = _normalize_keep_axes(find.keep_axes, ax_sizes, context["plot_params"]) + keep_axes = _normalize_keep_axes(find.keep_axes, context["scan_grid"]) for ser in context["series_data"].values(): for index_of_ax, indexes_in_ax in keep_axes: @@ -640,13 +632,15 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: def _normalize_keep_axes( - keep_axes: KeepAxisSpec, ax_sizes: dict[str, int], plot_params: list[str] + keep_axes: KeepAxisSpec, scan_grid: dict[str, Sequence[Any]] ) -> tuple[tuple[int, tuple[int, ...]], ...]: + ax_sizes = {ax_name: len(vals) for ax_name, vals in scan_grid.items()} if ... in keep_axes: keep_axes = _expand_ellipsis_axis(keep_axes, ax_sizes) # type: ignore else: keep_axes = cast(Collection[tuple[str, tuple[int, ...]]], keep_axes) - return tuple((plot_params.index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes) + scan_axes = tuple(ax_sizes.keys()) + return tuple((scan_axes.index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes) def _expand_ellipsis_axis( diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index 4c88da7..90b000a 100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -106,8 +106,9 @@ class GridsearchResultDetails(TypedDict): series_data: dict[str, SeriesData] metrics: tuple[str, ...] 
grid_params: list[str] - plot_params: list[str] - grid_vals: list[Sequence] + grid_vals: list[Sequence[Any]] + scan_grid: dict[str, Sequence[Any]] + plot_grid: dict[str, Sequence[Any]] main: float diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 5c8ac8d..2de5eeb 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -29,7 +29,7 @@ class _PlotPrefs: """ plot: bool = True - rel_noise: Literal[False] | Callable[..., tuple[list[Any], list[str]]] = False + rel_noise: Literal[False] | Callable[..., dict[str, Sequence[Any]]] = False plot_match: GridLocator = field(default_factory=lambda: GridLocator()) def __bool__(self): diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 4dc4770..2733711 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -163,8 +163,9 @@ def gridsearch_results(): "plot_data": [want, dont_want], "series_data": {"foo": max_amax}, "metrics": ("mse", "mae"), + "scan_grid": {"sim_params.t_end": [1, 2], "sim_params.noise": [5, 6]}, + "plot_grid": {}, "grid_params": ["sim_params.t_end", "bar", "sim_params.noise"], - "plot_params": ["sim_params.t_end", "sim_params.noise"], "grid_vals": [[1, 2], [7, 8], [5, 6]], "main": 1, } From dea553bdd94902fa92ed208b063ddb2a9b362175 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 27 Mar 2024 22:20:10 +0000 Subject: [PATCH 36/46] fix (gridsearch): Use same random seed in every trial --- src/gen_experiments/gridsearch/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 6ea0902..63c87d3 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -211,14 +211,12 @@ def run( gridpoint_selector = _ndindex_skinny( full_results_shape[1:], ind_skinny, where_others ) - rng = np.random.default_rng(seed) for 
ind_counter, ind in enumerate(gridpoint_selector): print(f"Calculating series {s_counter}, gridpoint{ind_counter}", end="\r") - new_seed = rng.integers(1000) for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals): curr_other_params[key] = val_list[axis_ind] curr_results, grid_data = base_ex.run( - new_seed, **curr_other_params, display=False, return_all=True + seed, **curr_other_params, display=False, return_all=True ) intermediate_data.append( {"params": curr_other_params.flatten(), "pind": ind, "data": grid_data} From 49cee792502a929ba05fb3e33b676aad86f0d465 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 28 Mar 2024 10:41:00 +0000 Subject: [PATCH 37/46] type: Remove genericness from ExpResult Python 3.10 does not allow generic TypedDicts --- .github/workflows/main.yaml | 2 +- src/gen_experiments/__init__.py | 3 +-- src/gen_experiments/gridsearch/typing.py | 6 ++---- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index d7c9b33..24a56a4 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -36,7 +36,7 @@ jobs: pip install -e .[dev] - name: run mypy run: | - mypy -v + mypy Tests: diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index b9708aa..c7174f1 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -4,7 +4,6 @@ from typing import Any import numpy as np -from mitosis import Experiment from numpy.typing import NDArray from pysindy import BaseDifferentiation, FiniteDifference, SINDy @@ -68,7 +67,7 @@ def run( return metrics -experiments: dict[str, tuple[Experiment, str | None]] = { +experiments: dict[str, tuple[Any, str | None]] = { "sho": (odes, "sho"), "lorenz": (odes, "lorenz"), "lorenz_2d": (odes, "lorenz_2d"), diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py index 90b000a..20e551e 
100644 --- a/src/gen_experiments/gridsearch/typing.py +++ b/src/gen_experiments/gridsearch/typing.py @@ -6,7 +6,6 @@ Any, Callable, Collection, - Generic, Optional, Sequence, TypedDict, @@ -83,10 +82,9 @@ class GridLocator: ] ExpResult = dict[str, Any] -ExpResultVar = TypeVar("ExpResultVar", bound=ExpResult) -class SavedGridPoint(TypedDict, Generic[ExpResultVar]): +class SavedGridPoint(TypedDict): """The results at a point in the gridsearch. Args: @@ -97,7 +95,7 @@ class SavedGridPoint(TypedDict, Generic[ExpResultVar]): params: dict pind: tuple[int, ...] - data: ExpResultVar + data: ExpResult class GridsearchResultDetails(TypedDict): From ad211f06cf7799146c8a4ef92814dd222b3a12a3 Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Sat, 6 Apr 2024 13:45:17 -0700 Subject: [PATCH 38/46] ENH: Set unbias=True for MIOSR optimizer --- src/gen_experiments/config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 063b021..f87602a 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -160,29 +160,29 @@ def addn(x): } opt_params = { "test": ND({"optcls": "STLSQ"}), - "miosr": ND({"optcls": "MIOSR"}), + "miosr": ND({"optcls": "MIOSR", "unbias": True}), "enslsq": ND( {"optcls": "ensemble", "opt": ps.STLSQ(), "bagging": True, "n_models": 20} ), "ensmio-ho-vdp-lv-duff": ND({ "optcls": "ensemble", - "opt": ps.MIOSR(target_sparsity=4), + "opt": ps.MIOSR(target_sparsity=4, unbias=True), "bagging": True, "n_models": 20, }), "ensmio-hopf": ND({ "optcls": "ensemble", - "opt": ps.MIOSR(target_sparsity=8), + "opt": ps.MIOSR(target_sparsity=8, unbias=True), "bagging": True, "n_models": 20, }), "ensmio-lorenz-ross": ND({ "optcls": "ensemble", - "opt": ps.MIOSR(target_sparsity=7), + "opt": ps.MIOSR(target_sparsity=7, unbias=True), "bagging": True, "n_models": 20, }), - "mio-lorenz-ross": ND({"optcls": "MIOSR", "target_sparsity": 7}), + "mio-lorenz-ross": ND({"optcls": 
"MIOSR", "target_sparsity": 7, "unbias": True}), } # Grid search parameters @@ -323,7 +323,7 @@ def addn(x): ), "auto-kalman3": SeriesDef( "Auto Kalman", - diff_params["kalman-auto"], + diff_params["kalman"], ["diff_params.alpha"], [(None,)], ), From fe6cf2713d42f61e694f01f532eab10a577b3399 Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Sat, 6 Apr 2024 13:51:18 -0700 Subject: [PATCH 39/46] ENH: modified plot_prefs, sim_params.rel_noise and skinny_specs --- src/gen_experiments/config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index f87602a..d823045 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -54,11 +54,11 @@ def addn(x): plot_prefs = { "test": _PlotPrefs(), "test-absrel": _PlotPrefs( - True, _convert_abs_rel_noise, GridLocator(..., {("sim_params.noise_abs", (1,))}) + True, False, GridLocator(..., {("sim_params.noise_abs", (1,))}) ), "test-absrel2": _PlotPrefs( True, - _convert_abs_rel_noise, + False, GridLocator( ..., (..., ...), @@ -74,7 +74,7 @@ def addn(x): ), "absrel-newloc": _PlotPrefs( True, - _convert_abs_rel_noise, + False, GridLocator( ["coeff_mse", "coeff_f1"], (..., (2, 3, 4)), @@ -288,7 +288,7 @@ def addn(x): "lorenzk": [[1, 9, 27], [0.1, 0.8], np.logspace(-6, -1, 4)], "lorenz1": [[1, 3, 9, 27], [0.01, 0.1, 1]], "duration-absnoise": [[0.5, 1, 2, 4, 8, 16], [0.1, 0.5, 1, 2, 4, 8]], - "rel_noise": [[0.25, 1, 4, 16], [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]], + "rel_noise": [[0.5, 1, 2, 4, 8, 16], [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]], } grid_decisions = { "test": ["plot"], @@ -424,5 +424,5 @@ def addn(x): ("sim_params.t_end", "sim_params.noise_abs", "diff_params.meas_var"), ((1, 1), (-1, identity), (-1, identity)), ), - "duration-noise": (("sim_params.t_end", "sim_params.noise_abs"), ((1,), (-1,))), + "duration-noise": (("sim_params.t_end", "sim_params.noise_rel"), ((1,), (-1,))), } From 676cede456f3c40427f1cd7b456b1a0436c30165 
Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 21 Apr 2024 20:29:16 -0700 Subject: [PATCH 40/46] feat (gridsearch): Allow narrower specification in find_gridpoints() Added several arguments to find_gridpoints(). Because matching searches both in argopt arrays across series, then checks points for matching parameters, it was possible to find points that matched the wrong series, but correct argopt requirements, which then matched the parameters. --- src/gen_experiments/gridsearch/__init__.py | 37 ++++++++++++++-------- tests/test_gridsearch.py | 6 +++- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index 63c87d3..b65901b 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -1,4 +1,4 @@ -from collections.abc import Iterable +from collections.abc import Collection, Iterable from copy import copy from functools import partial from logging import getLogger @@ -24,11 +24,11 @@ from scipy.stats import kstest import gen_experiments -from gen_experiments import config -from gen_experiments.odes import plot_ode_panel -from gen_experiments.plotting import _PlotPrefs -from gen_experiments.typing import FloatND, NestedDict -from gen_experiments.utils import simulate_test_data +from .. 
import config +from ..odes import plot_ode_panel +from ..plotting import _PlotPrefs +from ..typing import FloatND, NestedDict +from ..utils import simulate_test_data from .typing import ( ExpResult, @@ -38,6 +38,7 @@ KeepAxisSpec, OtherSliceDef, SavedGridPoint, + SeriesData, SeriesDef, SeriesList, SkinnySpecs, @@ -255,7 +256,13 @@ def run( ), } if plot_prefs: - plot_data = find_gridpoints(plot_prefs.plot_match, intermediate_data, results) + plot_data = find_gridpoints( + plot_prefs.plot_match, + intermediate_data, + results["series_data"].values(), + results["metrics"], + results["scan_grid"] + ) results["plot_data"] = plot_data for gridpoint in plot_data: grid_data = gridpoint["data"] @@ -569,7 +576,11 @@ def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> def find_gridpoints( - find: GridLocator, where: list[SavedGridPoint], context: GridsearchResultDetails + find: GridLocator, + where: list[SavedGridPoint], + argopt_arrs: Collection[SeriesData], + argopt_metrics: Sequence[str], + argopt_axes: dict[str, Sequence[object]], ) -> list[SavedGridPoint]: """Find results wrapped by gridsearch that match criteria @@ -582,20 +593,20 @@ def find_gridpoints( Returns: A list of the matching points in the gridsearch. """ - results: list[SavedGridPoint] = [] partial_match: set[tuple[int, ...]] = set() inds_of_metrics: Sequence[int] if find.metrics is ...: - inds_of_metrics = range(len(context["metrics"])) + inds_of_metrics = range(len(argopt_metrics)) else: inds_of_metrics = tuple( - context["metrics"].index(metric) for metric in find.metrics + argopt_metrics.index(metric) for metric in find.metrics ) # No deduplication is done! 
- keep_axes = _normalize_keep_axes(find.keep_axes, context["scan_grid"]) + keep_axes = _normalize_keep_axes(find.keep_axes, argopt_axes) - for ser in context["series_data"].values(): + ser: list[tuple[GridsearchResult[np.floating], GridsearchResult[np.void]]] + for ser in argopt_arrs: for index_of_ax, indexes_in_ax in keep_axes: amax_arr = ser[index_of_ax][1] amax_want = amax_arr[np.ix_(inds_of_metrics, indexes_in_ax)].flatten() diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 2733711..7e89647 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -195,7 +195,11 @@ def gridsearch_results(): def test_find_gridpoints(gridsearch_results, locator): want, full_details = gridsearch_results results = gridsearch.find_gridpoints( - locator, full_details["plot_data"], full_details + locator, + full_details["plot_data"], + full_details["series_data"].values(), + full_details["metrics"], + full_details["scan_grid"] ) assert [want] == results From 64d7e103c115e82d5b233012432b260c03696226 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 22 Apr 2024 20:46:59 -0700 Subject: [PATCH 41/46] fix (config): Plot prefs with callable params_or should accept float --- src/gen_experiments/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index d823045..1fb7ea7 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -82,7 +82,7 @@ def addn(x): {"diff_params.kind": "kalman", "diff_params.alpha": None}, { "diff_params.kind": "kalman", - "diff_params.alpha": lambda a: isinstance(a, int), + "diff_params.alpha": lambda a: isinstance(a, float | int), }, {"diff_params.kind": "trend_filtered"}, {"diff_params.diffcls": "SmoothedFiniteDifference"}, From 139470ee4b1374a06dd0a56f2d631d0ce15c606f Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 23 
Apr 2024 15:59:47 -0700 Subject: [PATCH 42/46] fix(gridsearch): link finding gridpoints to series --- .gitignore | 1 + src/gen_experiments/gridsearch/__init__.py | 51 ++++++++++++---------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index a867520..6ff01f9 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ gen_experiments/trials/ scratch/ *.png debugme*.py +trials/* # IDE files .vscode diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index b65901b..d66a324 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -4,17 +4,7 @@ from logging import getLogger from pprint import pformat from types import EllipsisType as ellipsis -from typing import ( - Annotated, - Any, - Callable, - Collection, - Optional, - Sequence, - TypeVar, - Union, - cast, -) +from typing import Annotated, Any, Callable, Optional, Sequence, TypeVar, Union, cast from warnings import warn import matplotlib.pyplot as plt @@ -24,12 +14,12 @@ from scipy.stats import kstest import gen_experiments + from .. import config from ..odes import plot_ode_panel from ..plotting import _PlotPrefs from ..typing import FloatND, NestedDict from ..utils import simulate_test_data - from .typing import ( ExpResult, GridLocator, @@ -178,6 +168,7 @@ def run( docstring for _ndindex_skinny). By default, all plot axes are made skinny with respect to each other. 
""" + logger.info(f"Beginning gridsearch of system: {group}") other_params = NestedDict(**other_params) base_ex, base_group = gen_experiments.experiments[group] if series_params is None: @@ -213,7 +204,10 @@ def run( full_results_shape[1:], ind_skinny, where_others ) for ind_counter, ind in enumerate(gridpoint_selector): - print(f"Calculating series {s_counter}, gridpoint{ind_counter}", end="\r") + logger.info( + f"Calculating series {s_counter} ({series_data.name}), " + f"gridpoint {ind_counter} ({ind})" + ) for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals): curr_other_params[key] = val_list[axis_ind] curr_results, grid_data = base_ex.run( @@ -256,13 +250,26 @@ def run( ), } if plot_prefs: - plot_data = find_gridpoints( - plot_prefs.plot_match, - intermediate_data, - results["series_data"].values(), - results["metrics"], - results["scan_grid"] - ) + plot_data = [] + # todo - improve how plot_prefs.plot_match interacts with series + # This is a horrible hack, assuming a params_or for each series, ino + for series_data, params in zip( + series_params.series_list, + list(plot_prefs.plot_match.params_or), + strict=True, + ): + key = series_data.name + logger.info(f"Searching for matching points in series: {key}") + locator = GridLocator( + plot_prefs.plot_match.metrics, plot_prefs.plot_match.keep_axes, [params] + ) + plot_data += find_gridpoints( + locator, + intermediate_data, + [results["series_data"][key]], + results["metrics"], + results["scan_grid"], + ) results["plot_data"] = plot_data for gridpoint in plot_data: grid_data = gridpoint["data"] @@ -599,9 +606,7 @@ def find_gridpoints( if find.metrics is ...: inds_of_metrics = range(len(argopt_metrics)) else: - inds_of_metrics = tuple( - argopt_metrics.index(metric) for metric in find.metrics - ) + inds_of_metrics = tuple(argopt_metrics.index(metric) for metric in find.metrics) # No deduplication is done! 
keep_axes = _normalize_keep_axes(find.keep_axes, argopt_axes) From 77753b828bea7b44bcd5b5ba58de56f43f7fa91b Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 24 Apr 2024 10:53:52 -0700 Subject: [PATCH 43/46] feat: Add logging for durations of work --- pyproject.toml | 2 +- src/gen_experiments/gridsearch/__init__.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c223195..3f31a19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "mitosis >=0.4.0rc2", "derivative @ git+https://github.com/Jacob-Stevens-Haas/derivative@hyperparams", "pysindy[cvxpy,miosr] @ git+https://github.com/dynamicslab/pysindy@master", - "kalman @ git+https://github.com/Jacob-Stevens-Haas/kalman@0.1.0", + "kalman @ git+https://github.com/Jacob-Stevens-Haas/kalman", "auto_ks @ git+https://github.com/cvxgrp/auto_ks.git@e60bcc6", "pytest >= 6.0.0", "pytest-cov", diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index d66a324..e5a2f1c 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -3,6 +3,7 @@ from functools import partial from logging import getLogger from pprint import pformat +from time import process_time from types import EllipsisType as ellipsis from typing import Annotated, Any, Callable, Optional, Sequence, TypeVar, Union, cast from warnings import warn @@ -208,6 +209,7 @@ def run( f"Calculating series {s_counter} ({series_data.name}), " f"gridpoint {ind_counter} ({ind})" ) + start = process_time() for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals): curr_other_params[key] = val_list[axis_ind] curr_results, grid_data = base_ex.run( @@ -219,6 +221,7 @@ def run( full_results[(slice(None), *ind)] = [ curr_results[metric] for metric in metrics ] + logger.info(f"Last calculation: {process_time() - start:.2f} sec.") 
grid_optima, grid_ind = _marginalize_grid_views( new_grid_decisions, full_results, metric_ordering ) @@ -260,6 +263,7 @@ def run( ): key = series_data.name logger.info(f"Searching for matching points in series: {key}") + start = process_time() locator = GridLocator( plot_prefs.plot_match.metrics, plot_prefs.plot_match.keep_axes, [params] ) @@ -270,14 +274,17 @@ def run( results["metrics"], results["scan_grid"], ) + logger.info(f"Searching took {process_time() - start:.2f} sec") results["plot_data"] = plot_data for gridpoint in plot_data: grid_data = gridpoint["data"] logger.info(f"Plotting: {gridpoint['params']}") + start = process_time() grid_data |= simulate_test_data( grid_data["model"], grid_data["dt"], grid_data["x_test"] ) plot_ode_panel(grid_data) # type: ignore + logger.info(f"Sim/Plot took {process_time() - start:.2f} sec") if plot_prefs.rel_noise: raise ValueError("_PlotPrefs.rel_noise is not correctly implemented.") else: From bd79e8fdf533ad7a1a2a6f01305bf39dc13b1b10 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 26 Apr 2024 14:18:05 -0700 Subject: [PATCH 44/46] fix (gridsearch): Make misalignment between series and plot_prefs a warning --- src/gen_experiments/gridsearch/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index e5a2f1c..3095973 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -256,10 +256,18 @@ def run( plot_data = [] # todo - improve how plot_prefs.plot_match interacts with series # This is a horrible hack, assuming a params_or for each series, ino + if len(series_params.series_list) != len(plot_prefs.plot_match.params_or): + msg = ( + "Trying to plot a subset of points tends to require the same" + "number of matchable parameter lists as series, lined up 1:1." + "You have a different number of each." 
+ ) + # TODO: write a warn_external function in mitosis for this: + warn(msg) + logger.warning(msg) for series_data, params in zip( series_params.series_list, list(plot_prefs.plot_match.params_or), - strict=True, ): key = series_data.name logger.info(f"Searching for matching points in series: {key}") From 31383318ae86fe3c662d683e0d4a9caf9a2def51 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 28 Apr 2024 11:10:32 -0700 Subject: [PATCH 45/46] CLN: black --- tests/test_gridsearch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py index 7e89647..1586d4f 100644 --- a/tests/test_gridsearch.py +++ b/tests/test_gridsearch.py @@ -199,7 +199,7 @@ def test_find_gridpoints(gridsearch_results, locator): full_details["plot_data"], full_details["series_data"].values(), full_details["metrics"], - full_details["scan_grid"] + full_details["scan_grid"], ) assert [want] == results From cbaa0592f3f347f99bb47bbed58d1dedc1d76b18 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sat, 11 May 2024 13:18:48 -0700 Subject: [PATCH 46/46] BLD: Pin dependencies --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3f31a19..0be17c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,9 +26,9 @@ classifiers = [ # dependencies dependencies = [ "mitosis >=0.4.0rc2", - "derivative @ git+https://github.com/Jacob-Stevens-Haas/derivative@hyperparams", - "pysindy[cvxpy,miosr] @ git+https://github.com/dynamicslab/pysindy@master", - "kalman @ git+https://github.com/Jacob-Stevens-Haas/kalman", + "derivative @ git+https://github.com/andgoldschmidt/derivative@f0d566d", + "pysindy[cvxpy,miosr] @ git+https://github.com/dynamicslab/pysindy@a43e217", + "kalman @ git+https://github.com/Jacob-Stevens-Haas/kalman@0.1.0", "auto_ks @ 
git+https://github.com/cvxgrp/auto_ks.git@e60bcc6", "pytest >= 6.0.0", "pytest-cov",