diff --git a/pyproject.toml b/pyproject.toml index 81edcf7..c223195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,13 @@ addopts = '-m "not slow"' markers = ["slow"] [tool.mypy] -files = ["src/gen_experiments/__init__.py", "src/gen_experiments/utils.py"] +files = [ + "src/gen_experiments/__init__.py", + "src/gen_experiments/utils.py", + "src/gen_experiments/gridsearch/typing.py", + "tests/test_all.py", + "tests/test_gridsearch.py", +] [[tool.mypy.overrides]] module="auto_ks.*" diff --git a/src/gen_experiments/__init__.py b/src/gen_experiments/__init__.py index 978a8e2..cd4a3dc 100644 --- a/src/gen_experiments/__init__.py +++ b/src/gen_experiments/__init__.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray -from pysindy import BaseDifferentiation, FiniteDifference, SINDy # type: ignore +from pysindy import BaseDifferentiation, FiniteDifference, SINDy from . import gridsearch, odes, pdes from .utils import SINDyTrialData, make_model # noqa: F401 diff --git a/src/gen_experiments/config.py b/src/gen_experiments/config.py index 47547ba..b90ece4 100644 --- a/src/gen_experiments/config.py +++ b/src/gen_experiments/config.py @@ -1,11 +1,19 @@ +from collections.abc import Iterable from typing import TypeVar import numpy as np import pysindy as ps from gen_experiments.data import _signal_avg_power +from gen_experiments.gridsearch.typing import ( + GridLocator, + SeriesDef, + SeriesList, + SkinnySpecs, +) from gen_experiments.plotting import _PlotPrefs -from gen_experiments.utils import FullSINDyTrialData, NestedDict, SeriesDef, SeriesList +from gen_experiments.typing import NestedDict +from gen_experiments.utils import FullSINDyTrialData T = TypeVar("T") U = TypeVar("U") @@ -41,77 +49,43 @@ def addn(x): plot_prefs = { - "test": _PlotPrefs(True, False, ({"sim_params.t_end": 10},)), + "test": _PlotPrefs(), "test-absrel": _PlotPrefs( - True, _convert_abs_rel_noise, ({"sim_params.noise_abs": 1},) + True, _convert_abs_rel_noise, GridLocator(..., 
{("sim_params.noise_abs", (1,))}) ), "test-absrel2": _PlotPrefs( True, _convert_abs_rel_noise, - ( - {"sim_params.noise_abs": 0.1}, - {"sim_params.noise_abs": 0.5}, - {"sim_params.noise_abs": 1}, - {"sim_params.noise_abs": 2}, - {"sim_params.noise_abs": 4}, - {"sim_params.noise_abs": 8}, + GridLocator( + ..., + (..., ...), + ( + {"sim_params.noise_abs": 0.1}, + {"sim_params.noise_abs": 0.5}, + {"sim_params.noise_abs": 1}, + {"sim_params.noise_abs": 2}, + {"sim_params.noise_abs": 4}, + {"sim_params.noise_abs": 8}, + ), ), ), - "test-absrel3": _PlotPrefs( + "absrel-newloc": _PlotPrefs( True, _convert_abs_rel_noise, - ( - { - "sim_params.noise_abs": 1, - "diff_params.smoother_kws.window_length": 15, - }, - {"sim_params.noise_abs": 1, "diff_params.meas_var": 1}, - {"sim_params.noise_abs": 1, "diff_params.alpha": 1e-2}, + GridLocator( + ["coeff_mse", "coeff_f1"], + (..., (2, 3, 4)), + ( + {"diff_params.kind": "kalman", "diff_params.alpha": None}, + { + "diff_params.kind": "kalman", + "diff_params.alpha": lambda a: isinstance(a, int), + }, + {"diff_params.kind": "trend_filtered"}, + {"diff_params.diffcls": "SmoothedFiniteDifference"}, + ), ), ), - "test-absrel4": _PlotPrefs( - True, - _convert_abs_rel_noise, - ( - { - "sim_params.noise_abs": 1, - "diff_params.smoother_kws.window_length": 15, - }, - {"sim_params.noise_abs": 1, "diff_params.meas_var": 1}, - {"sim_params.noise_abs": 1, "diff_params.alpha": 1e0}, - { - "sim_params.noise_abs": 2, - "diff_params.smoother_kws.window_length": 15, - }, - {"sim_params.noise_abs": 2, "diff_params.meas_var": 4}, - {"sim_params.noise_abs": 2, "diff_params.alpha": 1e-1}, - ), - ), - "test-absrel5": _PlotPrefs( - True, - _convert_abs_rel_noise, - ( - { - "sim_params.noise_abs": 1, - "diff_params.diffcls": "SmoothedFiniteDifference", - }, - {"sim_params.noise_abs": 1, "diff_params.kind": "kalman"}, - {"sim_params.noise_abs": 1, "diff_params.kind": "trend_filtered"}, - { - "sim_params.noise_abs": 2, - "diff_params.diffcls": 
"SmoothedFiniteDifference", - }, - {"sim_params.noise_abs": 2, "diff_params.kind": "kalman"}, - {"sim_params.noise_abs": 2, "diff_params.kind": "trend_filtered"}, - { - "sim_params.noise_abs": 4, - "diff_params.diffcls": "SmoothedFiniteDifference", - }, - {"sim_params.noise_abs": 4, "diff_params.kind": "kalman"}, - {"sim_params.noise_abs": 4, "diff_params.kind": "trend_filtered"}, - ), - {(0, 2), (3, 2), (0, 3), (3, 3), (0, 4), (3, 4)}, - ), } sim_params = { "test": ND({"n_trajectories": 2}), @@ -301,7 +275,7 @@ def addn(x): "duration-absnoise": ["sim_params.t_end", "sim_params.noise_abs"], "rel_noise": ["sim_params.t_end", "sim_params.noise_rel"], } -grid_vals = { +grid_vals: dict[str, list[Iterable]] = { "test": [[5, 10, 15, 20]], "abs_noise": [[0.1, 0.5, 1, 2, 4, 8]], "abs_noise-kalman": [[0.1, 0.5, 1, 2, 4, 8], [0.1, 0.5, 1, 2, 4, 8]], @@ -319,7 +293,7 @@ def addn(x): "lorenzk": ["plot", "plot", "max"], "plot2": ["plot", "plot"], } -diff_series = { +diff_series: dict[str, SeriesDef] = { "kalman1": SeriesDef( "Kalman", diff_params["kalman"], @@ -375,7 +349,7 @@ def addn(x): [[5, 8, 12, 15]], ), } -series_params = { +series_params: dict[str, SeriesList] = { "test": SeriesList( "diff_params", "Differentiation Method", @@ -434,7 +408,7 @@ def addn(x): } -skinny_specs = { +skinny_specs: dict[str, SkinnySpecs] = { "exp3": ( ("sim_params.noise_abs", "diff_params.meas_var"), ((identity,), (identity,)), diff --git a/src/gen_experiments/data.py b/src/gen_experiments/data.py index 0d651f2..0fc9760 100644 --- a/src/gen_experiments/data.py +++ b/src/gen_experiments/data.py @@ -7,7 +7,8 @@ import numpy as np import scipy -from gen_experiments.utils import Float1D, Float2D, GridsearchResultDetails +from gen_experiments.gridsearch.typing import GridsearchResultDetails +from gen_experiments.utils import Float1D, Float2D INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12} TRIALS_FOLDER = Path(__file__).parent.absolute() / "trials" @@ -56,7 +57,7 @@ def 
gen_data( noise_abs = 0.1 rng = np.random.default_rng(seed) if x0_center is None: - x0_center = np.zeros((n_coord)) + x0_center = np.zeros((n_coord), dtype=np.float_) t_train = np.arange(0, t_end, dt, dtype=np.float_) t_train_span = (t_train[0], t_train[-1]) if nonnegative: diff --git a/src/gen_experiments/debug.py b/src/gen_experiments/debug.py deleted file mode 100644 index a1193ca..0000000 --- a/src/gen_experiments/debug.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Annotated, Generic, TypedDict, TypeVar - -import numpy as np -from numpy.typing import DTypeLike, NBitBase, NDArray - -# T = TypeVar("T") - -# class Foo[T]: -# items: list[T] - -# def __init__(self, thing: T): -# self.items = [thing, thing] - -# Bar = - - -T = TypeVar("T", bound=np.generic) -Foo = NDArray[T] -Bar = Annotated[NDArray, "foobar"] - -lil_foo = NDArray[np.void] - - -def baz(qux: Foo[np.void]): - pass diff --git a/src/gen_experiments/gridsearch.py b/src/gen_experiments/gridsearch.py deleted file mode 100644 index e7aa488..0000000 --- a/src/gen_experiments/gridsearch.py +++ /dev/null @@ -1,380 +0,0 @@ -from copy import copy -from functools import partial -from logging import getLogger -from pprint import pformat -from typing import Annotated, Callable, Iterable, Optional, Sequence, TypeVar - -import matplotlib.pyplot as plt -import numpy as np -from numpy.typing import DTypeLike, NDArray -from scipy.stats import kstest - -import gen_experiments -from gen_experiments import config -from gen_experiments.odes import plot_ode_panel -from gen_experiments.plotting import _PlotPrefs -from gen_experiments.utils import ( - GridsearchResult, - GridsearchResultDetails, - NestedDict, - SavedData, - SeriesDef, - SeriesList, - SINDyTrialData, - _amax_to_full_inds, - _argopt, - _grid_locator_match, - simulate_test_data, -) - -pformat = partial(pformat, indent=4, sort_dicts=True) -logger = getLogger(__name__) -name = "gridsearch" -lookup_dict = vars(config) - -OtherSliceDef = tuple[int | Callable] 
-SkinnySpecs = Optional[tuple[tuple[str, ...], tuple[OtherSliceDef, ...]]] - - -def run( - seed: int, - group: str, - grid_params: list[str], - grid_vals: list[Sequence], - grid_decisions: Sequence[str], - other_params: dict, - series_params: Optional[SeriesList] = None, - metrics: Optional[Sequence[str]] = None, - plot_prefs: _PlotPrefs = _PlotPrefs(True, False, ()), - skinny_specs: SkinnySpecs = None, -) -> GridsearchResultDetails: - """Run a grid-search wrapper of an experiment. - - Arguments: - group: an experiment registered in gen_experiments. It must - have a name and a metric_ordering attribute - grid_params: kwarg names to grid and pass to - experiment - grid_vals: kwarg values to grid. Indices match grid_params - grid_decisions: What to do with each grid param, e.g. - {"plot", "best"}. Indices match grid_params. - other_params: a dict of other kwargs to pass to experiment - metrics: names of metrics to record from each wrapped experiment - plot_prefs: whether to plot results, and if so, a function to - intercept and modify plot data. Use this for applying any - scaling or conversions. - skinny_specs: Allow only conducting some of the grid search, - where axes are all searched, but not all combinations are - searched. The first element is a sequence of grid_names to - skinnify. The second is the thin_slices criteria (see - docstring for _ndindex_skinny). By default, all plot axes - are made skinny with respect to each other. 
- """ - other_params = NestedDict(**other_params) - base_ex, base_group = gen_experiments.experiments[group] - if series_params is None: - series_params = SeriesList(None, None, [SeriesDef(group, {}, [], [])]) - legends = False - else: - legends = True - n_metrics = len(metrics) - metric_ordering = [base_ex.metric_ordering[metric] for metric in metrics] - n_plotparams = len([decide for decide in grid_decisions if decide == "plot"]) - series_searches: list[tuple[list[GridsearchResult], list[GridsearchResult]]] = [] - intermediate_data: list[SavedData] = [] - plot_data: list[SavedData] = [] - if base_group is not None: - other_params["group"] = base_group - for s_counter, series_data in enumerate(series_params.series_list): - curr_other_params = copy(other_params) - if series_params.param_name is not None: - curr_other_params[series_params.param_name] = series_data.static_param - new_grid_vals: list = grid_vals + series_data.grid_vals - new_grid_params = grid_params + series_data.grid_params - new_grid_decisions = grid_decisions + len(series_data.grid_params) * ["best"] - if skinny_specs is not None: - ind_skinny, where_others = _curr_skinny_specs(skinny_specs, new_grid_params) - else: - ind_skinny = [ - ind for ind, decide in enumerate(new_grid_decisions) if decide == "plot" - ] - where_others = None - full_results_shape = (len(metrics), *(len(grid) for grid in new_grid_vals)) - full_results = np.full(full_results_shape, np.nan) - gridpoint_selector = _ndindex_skinny( - full_results_shape[1:], ind_skinny, where_others - ) - rng = np.random.default_rng(seed) - for ind_counter, ind in enumerate(gridpoint_selector): - print(f"Calculating series {s_counter}, gridpoint{ind_counter}", end="\r") - new_seed = rng.integers(1000) - param_updates = {} - for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals): - param_updates[key] = val_list[axis_ind] - curr_other_params.update(param_updates) - curr_results, grid_data = base_ex.run( - new_seed, 
**curr_other_params, display=False, return_all=True - ) - grid_data: SINDyTrialData - intermediate_data.append( - {"params": curr_other_params.flatten(), "pind": ind, "data": grid_data} - ) - full_results[(slice(None), *ind)] = [ - curr_results[metric] for metric in metrics - ] - grid_optima, grid_ind = _marginalize_grid_views( - new_grid_decisions, full_results, metric_ordering - ) - series_searches.append((grid_optima, grid_ind)) - - if plot_prefs: - full_m_inds = _amax_to_full_inds( - plot_prefs.grid_ind_match, [s[1] for s in series_searches] - ) - for int_data in intermediate_data: - logger.debug( - f"Checking whether to save/plot :\n{pformat(int_data['params'])}\n" - f"\tat location {pformat(int_data['pind'])}\n" - f"\tagainst spec: {pformat(plot_prefs.grid_params_match)}\n" - f"\twith allowed locations {pformat(full_m_inds)}" - ) - if _grid_locator_match( - int_data["params"], - int_data["pind"], - plot_prefs.grid_params_match, - full_m_inds, - ) and int_data["params"] not in [saved["params"] for saved in plot_data]: - grid_data = int_data["data"] - print("Results for params: ", int_data["params"], flush=True) - grid_data |= simulate_test_data( - grid_data["model"], grid_data["dt"], grid_data["x_test"] - ) - logger.info("Found match, simulating and plotting") - plot_ode_panel(grid_data) - plot_data.append(int_data) - if plot_prefs.rel_noise: - grid_vals, grid_params = plot_prefs.rel_noise( - grid_vals, grid_params, grid_data - ) - fig, subplots = plt.subplots( - n_metrics, - n_plotparams, - sharey="row", - sharex="col", - squeeze=False, - figsize=(n_plotparams * 3, 0.5 + n_metrics * 2.25), - ) - for series_data, series_name in zip( - series_searches, (ser.name for ser in series_params.series_list) - ): - plot( - fig, - subplots, - metrics, - grid_params, - grid_vals, - series_data[0], - series_name, - legends, - ) - if series_params.print_name is not None: - title = f"Grid Search on {series_params.print_name} in {group}" - else: - title = f"Grid Search in 
{group}" - fig.suptitle(title) - fig.tight_layout() - - main_metric_ind = metrics.index("main") if "main" in metrics else 0 - return { - "system": group, - "plot_data": plot_data, - "series_data": { - name: data - for data, name in zip( - [list(zip(metrics, argopts)) for metrics, argopts in series_searches], - [ser.name for ser in series_params.series_list], - ) - }, - "metrics": metrics, - "grid_params": grid_params, - "grid_vals": grid_vals, - "main": max( - grid[main_metric_ind].max() - for metrics, _ in series_searches - for grid in metrics - ), - } - - -def plot( - fig: plt.Figure, - subplots: Sequence[plt.Axes], - metrics: Sequence[str], - grid_params: Sequence[str], - grid_vals: Sequence[Sequence[float] | np.ndarray], - grid_searches: Sequence[GridsearchResult], - name: str, - legends: bool, -): - for m_ind_row, m_name in enumerate(metrics): - for col, (param_name, x_ticks, param_search) in enumerate( - zip(grid_params, grid_vals, grid_searches) - ): - ax = subplots[m_ind_row, col] - ax.plot(x_ticks, param_search[m_ind_row], label=name) - x_ticks = np.array(x_ticks) - if m_name in ("coeff_mse", "coeff_mae"): - ax.set_yscale("log") - x_ticks_normalized = (x_ticks - x_ticks.min()) / ( - x_ticks.max() - x_ticks.min() - ) - x_ticks_lognormalized = (np.log(x_ticks) - np.log(x_ticks).min()) / ( - np.log(x_ticks.max()) - np.log(x_ticks).min() - ) - ax = subplots[m_ind_row, col] - if kstest(x_ticks_normalized, "uniform") < kstest( - x_ticks_lognormalized, "uniform" - ): - ax.set_xscale("log") - if m_ind_row == 0: - ax.set_title(f"{param_name}") - if col == 0: - ax.set_ylabel(f"{m_name}") - if legends: - ax.legend() - - -T = TypeVar("T", bound=np.generic) - - -def _marginalize_grid_views( - grid_decisions: Iterable[str], - results: Annotated[NDArray[T], "shape (n_metrics, *n_gridsearch_values)"], - max_or_min: Sequence[str] = None, -) -> tuple[list[GridsearchResult[T]], list[GridsearchResult]]: - """Marginalize unnecessary dimensions by taking max across axes. 
- - Ignores NaN values - Args: - grid_decisions: list of how to treat each non-metric gridsearch - axis. An array of metrics for each "plot" grid decision - will be returned, along with an array of the the index - of collapsed dimensions that returns that metric - results: An array of shape (n_metrics, *n_gridsearch_values) - max_or_min: either "max" or "min" for each row of results - Returns: - a list of the metric optima for each plottable grid decision, and - a list of the flattened argoptima. - """ - arg_dtype: DTypeLike = ",".join(results.ndim * "i") - plot_param_inds = [ind for ind, val in enumerate(grid_decisions) if val == "plot"] - grid_searches = [] - args_maxes = [] - optfuns = [np.nanmax if opt == "max" else np.nanmin for opt in max_or_min] - for param_ind in plot_param_inds: - reduce_axes = tuple(set(range(results.ndim - 1)) - {param_ind}) - selection_results = np.array( - [opt(result, axis=reduce_axes) for opt, result in zip(optfuns, results)] - ) - sub_arrs = [] - for m_ind, (result, opt) in enumerate(zip(results, max_or_min)): - pad_m_ind = np.vectorize(lambda tp: np.void((m_ind, *tp), dtype=arg_dtype)) - arg_max = pad_m_ind(_argopt(result, reduce_axes, opt)) - sub_arrs.append(arg_max) - - args_max = np.stack(sub_arrs) - grid_searches.append(selection_results) - args_maxes.append(args_max) - return grid_searches, args_maxes - - -def _ndindex_skinny( - shape: tuple[int], - thin_axes: Optional[Sequence[int]] = None, - thin_slices: Optional[Sequence[OtherSliceDef]] = None, -): - """ - Return an iterator like ndindex, but only traverse thin_axes once - - This is useful for grid searches with multiple plot axes, where - searching across all combinations of plot axes is undesirable. - Slow for big arrays! 
(But still probably trivial compared to the - gridsearch operation :)) - - Args: - shape: array shape - thin_axes: axes for which you don't want the product of all - indexes - thin_slices: the indexes for other thin axes when traversing - a particular thin axis. Defaults to 0th index - - Example: - - >>> set(_ndindex_skinny((2,2), (0,1), ((0,), (lambda x: x,)))) - - {(0, 0), (0, 1), (1, 1)} - """ - if thin_axes is None and thin_slices is None: - thin_axes = () - thin_slices = () - elif thin_axes is None: - raise ValueError("Must pass thin_axes if thin_slices is not None") - elif thin_slices is None: # slice other thin axes at 0th index - n_thin = len(thin_axes) - thin_slices = n_thin * ((n_thin - 1) * (0,),) - full_indexes = np.ndindex(shape) - - def ind_checker(multi_index): - """Check if a multi_index meets thin index criteria""" - matches = [] - # check whether multi_index matches criteria of any thin_axis - for ax1, where_others in zip(thin_axes, thin_slices, strict=True): - other_axes = list(thin_axes) - other_axes.remove(ax1) - match = True - # check whether multi_index meets criteria of a particular thin_axis - for ax2, slice_ind in zip(other_axes, where_others, strict=True): - if callable(slice_ind): - slice_ind = slice_ind(multi_index[ax1]) - # would check: "== slice_ind", but must allow slice_ind = -1 - match *= multi_index[ax2] == range(shape[ax2])[slice_ind] - matches.append(match) - return any(matches) - - while True: - try: - ind = next(full_indexes) - except StopIteration: - break - if ind_checker(ind): - yield ind - - -def _curr_skinny_specs( - skinny_specs: SkinnySpecs, grid_params: list[str] -) -> tuple[Sequence[int], Sequence[OtherSliceDef]]: - """Calculate which skinny specs apply to current parameters""" - skinny_param_inds = [ - grid_params.index(pname) for pname in skinny_specs[0] if pname in grid_params - ] - missing_sk_inds = [ - skinny_specs[0].index(pname) - for pname in skinny_specs[0] - if pname not in grid_params - ] - where_others = 
[] - for orig_sk_ind, match_criteria in zip( - range(len(skinny_specs[0])), skinny_specs[1], strict=True - ): - if orig_sk_ind in missing_sk_inds: - continue - missing_criterion_inds = tuple( - sk_ind if sk_ind < orig_sk_ind else sk_ind - 1 for sk_ind in missing_sk_inds - ) - new_criteria = tuple( - match_criterion - for cr_ind, match_criterion in enumerate(match_criteria) - if cr_ind not in missing_criterion_inds - ) - where_others.append(new_criteria) - return skinny_param_inds, tuple(where_others) diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py new file mode 100644 index 0000000..941d17d --- /dev/null +++ b/src/gen_experiments/gridsearch/__init__.py @@ -0,0 +1,669 @@ +from collections.abc import Iterable +from copy import copy +from functools import partial +from logging import getLogger +from pprint import pformat +from types import EllipsisType as ellipsis +from typing import ( + Annotated, + Any, + Callable, + Collection, + Optional, + Sequence, + TypeVar, + Union, + cast, +) +from warnings import warn + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from numpy.typing import DTypeLike, NDArray +from scipy.stats import kstest + +import gen_experiments +from gen_experiments import config +from gen_experiments.odes import plot_ode_panel +from gen_experiments.plotting import _PlotPrefs +from gen_experiments.typing import FloatND, NestedDict +from gen_experiments.utils import simulate_test_data + +from .typing import ( + ExpResult, + GridLocator, + GridsearchResult, + GridsearchResultDetails, + OtherSliceDef, + SavedGridPoint, + SeriesDef, + SeriesList, + SkinnySpecs, +) + +pformat = partial(pformat, indent=4, sort_dicts=True) +logger = getLogger(__name__) +name = "gridsearch" +lookup_dict = vars(config) + + +def _amax_to_full_inds( + amax_inds: Collection[tuple[int | slice, int] | ellipsis] | ellipsis, + amax_arrays: list[list[GridsearchResult[np.void]]], +) -> 
set[tuple[int, ...]]: + """Find full indexers to selected elements of argmax arrays + + Args: + amax_inds: selection statemtent of which argmaxes to return. + amax_arrays: arrays of indexes to full gridsearch that are responsible for + the computed max values. First level of nesting reflects series, second + level reflects which grid grid axis. + Returns: + all indexers to full gridsearch that are requested by amax_inds + """ + + if amax_inds is ...: # grab each element from arrays in list of lists of arrays + return { + np_to_primitive(el) + for ar_list in amax_arrays + for arr in ar_list + for el in arr.flatten() + } + all_inds = set() + for plot_axis_results in [el for series in amax_arrays for el in series]: + for ind in amax_inds: + if ind is ...: # grab each element from arrays in list of lists of arrays + all_inds |= { + np_to_primitive(el) + for ar_list in amax_arrays + for arr in ar_list + for el in arr.flatten() + } + elif isinstance(ind[0], int): + all_inds |= {np_to_primitive(cast(np.void, plot_axis_results[ind]))} + else: # ind[0] is slice(None) + all_inds |= {np_to_primitive(el) for el in plot_axis_results[ind]} + return all_inds + + +_EqTester = TypeVar("_EqTester") + + +def _param_normalize(val: _EqTester) -> _EqTester | str: + """Allow equality testing of mutable objects with useful reprs""" + if type(val).__eq__ == object.__eq__: + return repr(val) + else: + return val + + +def _grid_locator_match( + exp_params: dict[str, Any], + exp_ind: tuple[int, ...], + param_spec: Collection[dict[str, Any]], + ind_spec: Collection[tuple[int, ...]], +) -> bool: + """Determine whether experimental parameters match a specification + + Logical clause applied is: + + OR((exp_params MATCHES params for params in param_spec)) + AND + OR((exp_ind MATCHES ind for ind in ind_spec)) + + Treats OR of an empty collection as falsy + Args: + exp_params: the experiment parameters to evaluate + exp_ind: the experiemnt's full-size grid index to evaluate + param_spec: the 
criteria for matching exp_params + ind_spec: the criteria for matching exp_ind + """ + warn("Use find_gridpoints() instead", DeprecationWarning) + found_match = False + for params_or in param_spec: + params_or = {k: _param_normalize(v) for k, v in params_or.items()} + + try: + if all( + _param_normalize(exp_params[param]) == value + for param, value in params_or.items() + ): + found_match = True + break + except KeyError: + pass + for ind_or in ind_spec: + # exp_ind doesn't include metric, so skip first metric + if _index_in(exp_ind, ind_or[1:]): + break + else: + return False + return found_match + + +def run( + seed: int, + group: str, + grid_params: list[str], + grid_vals: list[Sequence], + grid_decisions: list[str], + other_params: dict, + skinny_specs: Optional[SkinnySpecs] = None, + series_params: Optional[SeriesList] = None, + metrics: tuple[str, ...] = (), + plot_prefs: _PlotPrefs = _PlotPrefs(), +) -> GridsearchResultDetails: + """Run a grid-search wrapper of an experiment. + + Arguments: + group: an experiment registered in gen_experiments. It must + have a name and a metric_ordering attribute + grid_params: kwarg names to grid and pass to + experiment + grid_vals: kwarg values to grid. Indices match grid_params + grid_decisions: What to do with each grid param, e.g. + {"plot", "best"}. Indices match grid_params. + other_params: a dict of other kwargs to pass to experiment + metrics: names of metrics to record from each wrapped experiment + plot_prefs: whether to plot results, and if so, a function to + intercept and modify plot data. Use this for applying any + scaling or conversions. + skinny_specs: Allow only conducting some of the grid search, + where axes are all searched, but not all combinations are + searched. The first element is a sequence of grid_names to + skinnify. The second is the thin_slices criteria (see + docstring for _ndindex_skinny). By default, all plot axes + are made skinny with respect to each other. 
+ """ + other_params = NestedDict(**other_params) + base_ex, base_group = gen_experiments.experiments[group] + if series_params is None: + series_params = SeriesList(None, None, [SeriesDef(group, {}, [], [])]) + legends = False + else: + legends = True + n_metrics = len(metrics) + metric_ordering = [base_ex.metric_ordering[metric] for metric in metrics] + n_plotparams = len([decide for decide in grid_decisions if decide == "plot"]) + series_searches: list[tuple[list[GridsearchResult], list[GridsearchResult]]] = [] + intermediate_data: list[SavedGridPoint] = [] + plot_data: list[SavedGridPoint] = [] + if base_group is not None: + other_params["group"] = base_group + for s_counter, series_data in enumerate(series_params.series_list): + curr_other_params = copy(other_params) + if series_params.param_name is not None: + curr_other_params[series_params.param_name] = series_data.static_param + new_grid_vals: list = grid_vals + series_data.grid_vals + new_grid_params = grid_params + series_data.grid_params + new_grid_decisions = grid_decisions + len(series_data.grid_params) * ["best"] + if skinny_specs is not None: + ind_skinny, where_others = _curr_skinny_specs(skinny_specs, new_grid_params) + else: + ind_skinny = [ + ind for ind, decide in enumerate(new_grid_decisions) if decide == "plot" + ] + where_others = None + full_results_shape = (len(metrics), *(len(grid) for grid in new_grid_vals)) + full_results = np.full(full_results_shape, np.nan) + gridpoint_selector = _ndindex_skinny( + full_results_shape[1:], ind_skinny, where_others + ) + rng = np.random.default_rng(seed) + for ind_counter, ind in enumerate(gridpoint_selector): + print(f"Calculating series {s_counter}, gridpoint{ind_counter}", end="\r") + new_seed = rng.integers(1000) + for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals): + curr_other_params[key] = val_list[axis_ind] + curr_results, grid_data = base_ex.run( + new_seed, **curr_other_params, display=False, return_all=True + ) + 
intermediate_data.append( + {"params": curr_other_params.flatten(), "pind": ind, "data": grid_data} + ) + full_results[(slice(None), *ind)] = [ + curr_results[metric] for metric in metrics + ] + grid_optima, grid_ind = _marginalize_grid_views( + new_grid_decisions, full_results, metric_ordering + ) + series_searches.append((grid_optima, grid_ind)) + + main_metric_ind = metrics.index("main") if "main" in metrics else 0 + results: GridsearchResultDetails = { + "system": group, + "plot_data": [], + "series_data": { + name: data + for data, name in zip( + [list(zip(metrics, argopts)) for metrics, argopts in series_searches], + [ser.name for ser in series_params.series_list], + ) + }, + "metrics": metrics, + "grid_params": grid_params, + "plot_params": [ + param + for decide, param in zip(grid_decisions, grid_params) + if decide == "plot" + ], + "grid_vals": grid_vals, + "main": max( + grid[main_metric_ind].max() + for metrics, _ in series_searches + for grid in metrics + ), + } + if plot_prefs: + plot_data = find_gridpoints(plot_prefs.plot_match, intermediate_data, results) + results["plot_data"] = plot_data + for gridpoint in plot_data: + grid_data = gridpoint["data"] + logger.info(f"Plotting: {gridpoint['params']}") + grid_data |= simulate_test_data( + grid_data["model"], grid_data["dt"], grid_data["x_test"] + ) + plot_ode_panel(grid_data) # type: ignore + if plot_prefs.rel_noise: + grid_vals, grid_params = plot_prefs.rel_noise( + grid_vals, grid_params, grid_data + ) + fig, subplots = plt.subplots( + n_metrics, + n_plotparams, + sharey="row", + sharex="col", + squeeze=False, + figsize=(n_plotparams * 3, 0.5 + n_metrics * 2.25), + ) + for series_data, series_name in zip( + series_searches, (ser.name for ser in series_params.series_list) + ): + plot( + subplots, + metrics, + grid_params, + grid_vals, + series_data[0], + series_name, + legends, + ) + if series_params.print_name is not None: + title = f"Grid Search on {series_params.print_name} in {group}" + else: + 
title = f"Grid Search in {group}" + fig.suptitle(title) + fig.tight_layout() + + return results + + +def plot( + subplots: NDArray[Annotated[np.void, "Axes"]], + metrics: Sequence[str], + grid_params: Sequence[str], + grid_vals: Sequence[Sequence[float] | np.ndarray], + grid_searches: Sequence[GridsearchResult], + name: str, + legends: bool, +): + if len(metrics) == 0: + raise ValueError("Nothing to plot") + for m_ind_row, m_name in enumerate(metrics): + for col, (param_name, x_ticks, param_search) in enumerate( + zip(grid_params, grid_vals, grid_searches) + ): + ax = cast(Axes, subplots[m_ind_row, col]) + ax.plot(x_ticks, param_search[m_ind_row], label=name) + x_ticks = np.array(x_ticks) + if m_name in ("coeff_mse", "coeff_mae"): + ax.set_yscale("log") + x_ticks_normalized = (x_ticks - x_ticks.min()) / ( + x_ticks.max() - x_ticks.min() + ) + x_ticks_lognormalized = (np.log(x_ticks) - np.log(x_ticks).min()) / ( + np.log(x_ticks.max()) - np.log(x_ticks).min() + ) + ax = subplots[m_ind_row, col] + if kstest(x_ticks_normalized, "uniform") < kstest( + x_ticks_lognormalized, "uniform" + ): + ax.set_xscale("log") + if m_ind_row == 0: + ax.set_title(f"{param_name}") + if col == 0: + ax.set_ylabel(f"{m_name}") + if legends: + ax.legend() # type: ignore + + +T = TypeVar("T", bound=np.generic) + + +def _argopt( + arr: FloatND, axis: Optional[int | tuple[int, ...]] = None, opt: str = "max" +) -> NDArray[np.void]: + """Calculate the argmax/min, but accept tuple axis. + + Ignores NaN values + + Args: + arr: an array to search + axis: The axis or axes to search through for the argmax/argmin. + opt: One of {"max", "min"} + + Returns: + array of indices for the argopt. 
If m = arr.ndim and n = len(axis), + the final result will be an array of ndim = m-n with elements being + tuples of length m + """ + dtype: DTypeLike = [(f"f{axind}", "i") for axind in range(arr.ndim)] + if axis is None: + axis = () + axis = (axis,) if isinstance(axis, int) else axis + keep_axes = tuple(sorted(set(range(arr.ndim)) - set(axis))) + keep_shape = tuple(arr.shape[ax] for ax in keep_axes) + result = np.empty(keep_shape, dtype=dtype) + optfun = np.nanargmax if opt == "max" else np.nanargmin + for slise in np.ndindex(keep_shape): + sub_arr = arr + # since we shrink shape, we need to chop of axes from the end + for ind, ax in zip(reversed(slise), reversed(keep_axes)): + sub_arr = np.take(sub_arr, ind, ax) + subind_max = np.unravel_index(optfun(sub_arr), sub_arr.shape) + fullind_max = np.empty((arr.ndim), int) + fullind_max[np.array(keep_axes, int)] = slise + fullind_max[np.array(axis, int)] = subind_max + result[slise] = tuple(fullind_max) + return result + + +def _marginalize_grid_views( + grid_decisions: Iterable[str], + results: Annotated[NDArray[T], "shape (n_metrics, *n_gridsearch_values)"], + max_or_min: Sequence[str], +) -> tuple[list[GridsearchResult[T]], list[GridsearchResult]]: + """Marginalize unnecessary dimensions by taking max across axes. + + Ignores NaN values + Args: + grid_decisions: list of how to treat each non-metric gridsearch + axis. An array of metrics for each "plot" grid decision + will be returned, along with an array of the the index + of collapsed dimensions that returns that metric + results: An array of shape (n_metrics, *n_gridsearch_values) + max_or_min: either "max" or "min" for each row of results + Returns: + a list of the metric optima for each plottable grid decision, and + a list of the flattened argoptima. 
+ """ + arg_dtype = np.dtype(",".join(results.ndim * "i")) + plot_param_inds = [ind for ind, val in enumerate(grid_decisions) if val == "plot"] + grid_searches = [] + args_maxes = [] + optfuns = [np.nanmax if opt == "max" else np.nanmin for opt in max_or_min] + for param_ind in plot_param_inds: + reduce_axes = tuple(set(range(results.ndim - 1)) - {param_ind}) + selection_results = np.array( + [opt(result, axis=reduce_axes) for opt, result in zip(optfuns, results)] + ) + sub_arrs = [] + for m_ind, (result, opt) in enumerate(zip(results, max_or_min)): + + def _metric_pad(tp: tuple[int, ...]) -> np.void: + return np.void((m_ind, *tp), dtype=arg_dtype) + + pad_m_ind = np.vectorize(_metric_pad) + arg_max = pad_m_ind(_argopt(result, reduce_axes, opt)) + sub_arrs.append(arg_max) + + args_max = np.stack(sub_arrs) + grid_searches.append(selection_results) + args_maxes.append(args_max) + return grid_searches, args_maxes + + +def _ndindex_skinny( + shape: tuple[int, ...], + thin_axes: Optional[Sequence[int]] = None, + thin_slices: Optional[Sequence[OtherSliceDef]] = None, +): + """ + Return an iterator like ndindex, but only traverse thin_axes once + + This is useful for grid searches with multiple plot axes, where + searching across all combinations of plot axes is undesirable. + Slow for big arrays! (But still probably trivial compared to the + gridsearch operation :)) + + Args: + shape: array shape + thin_axes: axes for which you don't want the product of all + indexes + thin_slices: the indexes for other thin axes when traversing + a particular thin axis. 
Defaults to 0th index + + Example: + + >>> set(_ndindex_skinny((2,2), (0,1), ((0,), (lambda x: x,)))) + + {(0, 0), (1, 0), (1, 1)} + """ + if thin_axes is None and thin_slices is None: + thin_axes = () + thin_slices = () + elif thin_axes is None: + raise ValueError("Must pass thin_axes if thin_slices is not None") + elif thin_slices is None: # slice other thin axes at 0th index + n_thin = len(thin_axes) + thin_slices = n_thin * ((n_thin - 1) * (0,),) + full_indexes = np.ndindex(shape) + thin_slices = cast(Sequence[OtherSliceDef], thin_slices) + + def ind_checker(multi_index: tuple[int, ...]) -> bool: + """Check if a multi_index meets thin index criteria""" + matches = [] + # check whether multi_index matches criteria of any thin_axis + for ax1, where_others in zip(thin_axes, thin_slices, strict=True): + other_axes = list(thin_axes) + other_axes.remove(ax1) + match = True + # check whether multi_index meets criteria of a particular thin_axis + for ax2, slice_ind in zip(other_axes, where_others, strict=True): + if callable(slice_ind): + slice_ind = slice_ind(multi_index[ax1]) + # would check: "== slice_ind", but must allow slice_ind = -1 + match &= multi_index[ax2] == range(shape[ax2])[slice_ind] + matches.append(match) + return any(matches) + + while True: + try: + ind = next(full_indexes) + except StopIteration: + break + if ind_checker(ind): + yield ind + + +def _curr_skinny_specs( + skinny_specs: SkinnySpecs, grid_params: list[str] +) -> tuple[Sequence[int], Sequence[OtherSliceDef]]: + """Calculate which skinny specs apply to current parameters""" + skinny_param_inds = [ + grid_params.index(pname) for pname in skinny_specs[0] if pname in grid_params + ] + missing_sk_inds = [ + skinny_specs[0].index(pname) + for pname in skinny_specs[0] + if pname not in grid_params + ] + where_others = [] + for orig_sk_ind, match_criteria in zip( + range(len(skinny_specs[0])), skinny_specs[1], strict=True + ): + if orig_sk_ind in missing_sk_inds: + continue +
missing_criterion_inds = tuple( + sk_ind if sk_ind < orig_sk_ind else sk_ind - 1 for sk_ind in missing_sk_inds + ) + new_criteria = tuple( + match_criterion + for cr_ind, match_criterion in enumerate(match_criteria) + if cr_ind not in missing_criterion_inds + ) + where_others.append(new_criteria) + return skinny_param_inds, tuple(where_others) + + +def strict_find_grid_match( + results: GridsearchResultDetails, + *, + params: Optional[dict[str, Any]] = None, + ind_spec: Optional[tuple[int | slice, int] | ellipsis] = None, +) -> ExpResult: + if params is None: + params = {} + if ind_spec is None: + ind_spec = ... + matches = [] + amax_arrays = [ + [single_ser_and_axis[1] for single_ser_and_axis in single_series_all_axes] + for _, single_series_all_axes in results["series_data"].items() + ] + full_inds = _amax_to_full_inds((ind_spec,), amax_arrays) + + for trajectory in results["plot_data"]: + if _grid_locator_match( + trajectory["params"], trajectory["pind"], (params,), full_inds + ): + matches.append(trajectory) + + if len(matches) > 1: + raise ValueError("Specification is nonunique; matched multiple results") + if len(matches) == 0: + raise ValueError("Could not find a match") + return matches[0]["data"] + + +def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> bool: + """Determine whether base indexing tuple will match given numpy index""" + if len(base) > len(tgt): + return False + curr_ax = 0 + for ax, ind in enumerate(tgt): + if isinstance(ind, int): + try: + if ind != base[curr_ax]: + return False + except IndexError: + return False + elif isinstance(ind, slice): + if not (ind.start is None and ind.stop is None and ind.step is None): + raise ValueError("Only slices allowed are `slice(None)`") + elif ind is ...: + base_ind_remaining = len(base) - curr_ax + tgt_ind_remaining = len(tgt) - ax + # ellipsis can take 0 or more spots + curr_ax += max(base_ind_remaining - tgt_ind_remaining, -1) + curr_ax += 1 + if curr_ax == len(base): + 
return True + return False + + +def find_gridpoints( + find: GridLocator, where: list[SavedGridPoint], context: GridsearchResultDetails +) -> list[SavedGridPoint]: + """Find results wrapped by gridsearch that match criteria + + Args: + find: the criteria + where: The list of saved gridpoints to search + context: The overall data for the gridsearch, describing metrics, grid + setup, and gridsearch results + + Returns: + A list of the matching points in the gridsearch. + """ + + results: list[SavedGridPoint] = [] + partial_match: set[tuple[int, ...]] = set() + if find.metrics is ...: + inds_of_metrics = range(len(context["metrics"])) + else: + inds_of_metrics = tuple( + context["metrics"].index(metric) for metric in find.metrics + ) + ax_sizes = { + plot_ax: len(context["grid_vals"][context["grid_params"].index(plot_ax)]) + for plot_ax in context["plot_params"] + } + if ... in find.keep_axes: + keep_axes = _expand_ellipsis_axis(find.keep_axes, ax_sizes) # type: ignore + else: + keep_axes = cast(Collection[tuple[str, tuple[int, ...]]], find.keep_axes) + # No deduplication is done! 
+ keep_axes = tuple( + (context["grid_params"].index(keep_ax[0]), keep_ax[1]) for keep_ax in keep_axes + ) + + for ser in context["series_data"].values(): + for index_of_ax, indexes_in_ax in keep_axes: + amax_arr = ser[index_of_ax][1] + amax_want = amax_arr[np.ix_(inds_of_metrics, indexes_in_ax)].flatten() + partial_match |= {np_to_primitive(el) for el in amax_want} + logger.info( + f"Found {len(partial_match)} gridpoints that match metric-plot_axis criteria" + ) + + params_or = tuple( + {k: v if callable(v) else _param_normalize(v) for k, v in params_match.items()} + for params_match in find.params_or + ) + + def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool: + if callable(criteria): + return criteria(candidate) + else: + return _param_normalize(candidate) == criteria + + for point in filter(lambda p: p["pind"] in partial_match, where): + logger.debug(f"Checking whether {point['pind']} matches param query") + for params_match in params_or: + if all( + check_values(value, point["params"][param]) + for param, value in params_match.items() + ): + results.append(point) + break + + logger.info(f"found {len(results)} points that match all GridLocator criteria") + return results + + +def _expand_ellipsis_axis( + keep_axis: Union[ + tuple[ellipsis, ellipsis], + tuple[ellipsis, tuple[int, ...]], + tuple[tuple[str, ...], ellipsis], + ], + ax_sizes: dict[str, int], +) -> Collection[tuple[str, tuple[int, ...]]]: + if keep_axis[0] is ... 
and keep_axis[1] is ...: + # form 1 + return tuple((k, tuple(range(v))) for k, v in ax_sizes.items()) + elif isinstance(keep_axis[1], tuple): + # form 2 + return tuple((k, keep_axis[1]) for k in ax_sizes.keys()) + elif isinstance(keep_axis[0], tuple): + # form 3 + return tuple((k, tuple(range(ax_sizes[k]))) for k in keep_axis[0]) + else: + raise TypeError("Keep_axis does not have an ellipsis or is not a 2-tuple") + + +def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: + """Turn a void that represents a tuple of ints into a tuple of ints""" + return tuple(int(el) for el in cast(Iterable, tuple_like)) diff --git a/src/gen_experiments/gridsearch/typing.py b/src/gen_experiments/gridsearch/typing.py new file mode 100644 index 0000000..7e890f7 --- /dev/null +++ b/src/gen_experiments/gridsearch/typing.py @@ -0,0 +1,190 @@ +from collections.abc import Iterable +from dataclasses import dataclass, field +from types import EllipsisType as ellipsis +from typing import ( + Annotated, + Any, + Callable, + Collection, + Optional, + Sequence, + TypedDict, + TypeVar, + Union, +) + +import numpy as np +from numpy.typing import NDArray + +OtherSliceDef = tuple[(int | Callable[[int], int]), ...] +"""For a particular index of one gridsearch axis, which indexes of other axes +should be included.""" +SkinnySpecs = tuple[tuple[str, ...], tuple[OtherSliceDef, ...]] + + +KeepAxisSpec = Union[ + tuple[ellipsis, ellipsis], # all axes, all indices + tuple[ellipsis, tuple[int, ...]], # all axes, specific indices + tuple[tuple[str, ...], ellipsis], # specific axes, all indices + Collection[tuple[str, tuple[int, ...]]], # specific axes, specific indices +] + + +@dataclass(frozen=True) +class GridLocator: + """A specification of which points in a gridsearch to match. + + Rather than specifying the exact point in the mega-grid of every + varied axis, specify by result, e.g "all of the points from the + Kalman series that had the best mean squared error as noise was + varied. 
+ + Logical AND is applied across the metrics, keep_axes, AND params_or + specifications. + + Args: + metrics: The metrics in which to find results. An ellipsis means "any + metrics" + keep_axes: The grid-varied parameter in which to find results and which + index of values for that parameter. To search a particular value of + that parameter, use the params_or kwarg. It can be specified in + several ways: + (a) a tuple of two ellipses, representing all axes, all indices + (b) a tuple of an ellipsis, representing all axes, and a tuple of + ints for specific indices + (c) a tuple of a tuple of strings for specific axes, and an ellipsis + for all indices + (d) a collection of tuples of a string (specific axis) and tuple of + ints (specific indices) + params_or: A collection of dictionaries to match parameter values represented + by points in the gridsearch. Dictionary equality is checked for every + non-callable value; for callable values, it is applied to the grid + parameters and must return a boolean. For values whose equality is object + equality (often, mutable objects), the repr is used. Logical OR is applied + across the collection. + """ + + metrics: Collection[str] | ellipsis = field(default=...) + keep_axes: KeepAxisSpec = field(default=(..., ...)) + params_or: Collection[dict[str, Any]] = field(default=()) + + +T = TypeVar("T", bound=np.generic) +GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] +SeriesData = Annotated[ + list[ + tuple[ + Annotated[GridsearchResult, "metrics"], + Annotated[GridsearchResult[np.void], "arg_opts"], + ] + ], + "len=n_grid_axes", +] + +ExpResult = dict[str, Any] + + +class SavedGridPoint(TypedDict): + """The results at a point in the gridsearch. + + Args: + params: the full list of parameters identifying this variant + pind: the full index in the series' grid + data: the results of the experiment + """ + + params: dict + pind: tuple[int, ...]
+ data: ExpResult + + +class GridsearchResultDetails(TypedDict): + system: str + plot_data: list[SavedGridPoint] + series_data: dict[str, SeriesData] + metrics: tuple[str, ...] + grid_params: list[str] + plot_params: list[str] + grid_vals: list[Sequence] + main: float + + +@dataclass +class SeriesDef: + """The details of constructing the ragged axes of a grid search. + + The concept of a SeriesDef refers to a slice along a single axis of + a grid search in conjunction with another axis (or axes) + whose size or meaning differs along different slices. + + Attributes: + name: The name of the slice, as a label for printing + static_param: the constant parameter to this slice. The key is + the name of the parameter, as understood by the experiment. + Conceptually, the key serves as an index of this slice in + the gridsearch. + grid_params: the keys of the parameters in the experiment that + vary along the jagged axis for this slice + grid_vals: the values of the parameters in the experiment that + vary along the jagged axis for this slice + + Example: + + truck_wheels = SeriesDef( + "Truck", + {"vehicle": "flatbed_truck"}, + ["vehicle.n_wheels"], + [[10, 18]] + ) + + """ + + name: str + static_param: dict + grid_params: list[str] + grid_vals: list[Iterable] + + +@dataclass +class SeriesList: + """Specify the ragged slices of a grid search. + + As an example, consider a grid search of miles per gallon for + different vehicles, in different routes, with different tires. + Since different tires fit on different vehicles, the tire axis would + be ragged, varying along the vehicle axis. + + Truck = SeriesDef("trucks") + + Attributes: + param_name: the key of the parameter in the experiment that + varies along the series axis. + print_name: the print name of the parameter in the experiment + that varies along the series axis.
+ series_list: Each element of the series axis + + Example: + + truck_wheels = SeriesDef( + "Truck", + {"vehicle": "flatbed_truck"}, + ["vehicle.n_wheels"], + [[10, 18]] + ) + bike_tires = SeriesDef( + "Bike", + {"vehicle": "bicycle"}, + ["vehicle.tires"], + [["gravel_tires", "road_tires"]] + ) + VehicleOptions = SeriesList( + "vehicle", + "Vehicle Types", + [truck_wheels, bike_tires] + ) + + """ + + param_name: Optional[str] + print_name: Optional[str] + series_list: list[SeriesDef] diff --git a/src/gen_experiments/plotting.py b/src/gen_experiments/plotting.py index 0aa09bd..4a55e05 100644 --- a/src/gen_experiments/plotting.py +++ b/src/gen_experiments/plotting.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field -from types import EllipsisType as ellipsis -from typing import Annotated, Callable, Collection, Literal, Mapping, Sequence +from typing import Annotated, Callable, Literal, Sequence import matplotlib.pyplot as plt import numpy as np @@ -8,6 +7,8 @@ import seaborn as sns from matplotlib.axes import Axes +from .gridsearch.typing import GridLocator + PAL = sns.color_palette("Set1") PLOT_KWS = {"alpha": 0.7, "linewidth": 3} @@ -29,10 +30,7 @@ class _PlotPrefs: plot: bool = True rel_noise: Literal[False] | Callable = False - grid_params_match: Collection[dict] = field(default_factory=lambda: ()) - grid_ind_match: Collection[tuple[int | slice, int]] | ellipsis = field( - default_factory=lambda: ... 
- ) + plot_match: GridLocator = field(default_factory=lambda: GridLocator()) def __bool__(self): return self.plot @@ -40,9 +38,9 @@ def __bool__(self): def plot_coefficients( coefficients: Annotated[np.ndarray, "(n_coord, n_features)"], - input_features: Sequence[str] = None, - feature_names: Sequence[str] = None, - ax: bool = None, + input_features: Sequence[str], + feature_names: Sequence[str], + ax: Axes, **heatmap_kws, ): def detex(input: str) -> str: @@ -61,9 +59,6 @@ def detex(input: str) -> str: feature_names = [f"f{k}" for k in range(coefficients.shape[1])] with sns.axes_style(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}): - if ax is None: - fig, ax = plt.subplots(1, 1) - heatmap_args = { "xticklabels": input_features, "yticklabels": feature_names, @@ -85,8 +80,8 @@ def detex(input: str) -> str: def compare_coefficient_plots( coefficients_est: Annotated[np.ndarray, "(n_coord, n_feat)"], coefficients_true: Annotated[np.ndarray, "(n_coord, n_feat)"], - input_features: Sequence[str] = None, - feature_names: Sequence[str] = None, + input_features: Sequence[str], + feature_names: Sequence[str], ): """Create plots of true and estimated coefficients.""" n_cols = len(coefficients_est) @@ -203,6 +198,8 @@ def plot_training_data(x_train: np.ndarray, x_true: np.ndarray, x_smooth: np.nda ax0 = fig.add_subplot(1, 2, 1) elif x_train.shape[-1] == 3: ax0 = fig.add_subplot(1, 2, 1, projection="3d") + else: + raise ValueError("Too many or too few coordinates to plot") plot_training_trajectory(ax0, x_train, x_true, x_smooth) ax0.legend() ax0.set(title="Training data") @@ -276,7 +273,7 @@ def _plot_test_sim_data_3d( def plot_test_trajectories( x_test: np.ndarray, x_sim: np.ndarray, t_test: np.ndarray, t_sim: np.ndarray -) -> Mapping[str, np.ndarray]: +) -> None: """Plot a test trajectory Args: @@ -288,19 +285,17 @@ def plot_test_trajectories( A dict with two keys, "t_sim" (the simulation times) and "x_sim" (the simulated trajectory) """ - fig, axs = 
plt.subplots(x_test.shape[1], 1, sharex=True, figsize=(7, 9)) + _, axs = plt.subplots(x_test.shape[1], 1, sharex=True, figsize=(7, 9)) plt.suptitle("Test Trajectories by Dimension") plot_test_sim_data_1d_panel(axs, x_test, x_sim, t_test, t_sim) axs[-1].legend() plt.suptitle("Full Test Trajectories") if x_test.shape[1] == 2: - fig, axs = plt.subplots(1, 2, figsize=(10, 4.5)) + _, axs = plt.subplots(1, 2, figsize=(10, 4.5)) _plot_test_sim_data_2d(axs, x_test, x_sim) elif x_test.shape[1] == 3: - fig, axs = plt.subplots( - 1, 2, figsize=(10, 4.5), subplot_kw={"projection": "3d"} - ) + _, axs = plt.subplots(1, 2, figsize=(10, 4.5), subplot_kw={"projection": "3d"}) _plot_test_sim_data_3d(axs, x_test, x_sim) else: raise ValueError("Can only plot 2d or 3d data.") diff --git a/src/gen_experiments/typing.py b/src/gen_experiments/typing.py new file mode 100644 index 0000000..815ccd8 --- /dev/null +++ b/src/gen_experiments/typing.py @@ -0,0 +1,65 @@ +from collections import defaultdict +from typing import TypeVar + +import numpy as np +from numpy.typing import NBitBase + +NpFlt = np.dtype[np.floating[NBitBase]] +Float1D = np.ndarray[tuple[int], NpFlt] +Float2D = np.ndarray[tuple[int, int], NpFlt] +Shape = TypeVar("Shape", bound=tuple[int, ...]) +FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]] + + +class NestedDict(defaultdict): + """A dictionary that splits all keys by ".", creating a sub-dict. + + Args: see superclass + + Example: + + >>> foo = NestedDict("a.b"=1) + >>> foo["a.c"] = 2 + >>> foo["a"]["b"] + 1 + """ + + def __missing__(self, key): + try: + prefix, subkey = key.split(".", 1) + except ValueError: + raise KeyError(key) + return self[prefix][subkey] + + def __setitem__(self, key, value): + if "." 
in key: + prefix, suffix = key.split(".", 1) + if self.get(prefix) is None: + self[prefix] = NestedDict() + return self[prefix].__setitem__(suffix, value) + else: + return super().__setitem__(key, value) + + def update(self, other: dict): # type: ignore + try: + for k, v in other.items(): + self.__setitem__(k, v) + except: # noqa: E722 + super().update(other) + + def flatten(self): + """Flattens a nested dictionary without mutating. Returns new dict""" + + def _flatten(nested_d: dict) -> dict: + new = {} + for key, value in nested_d.items(): + if not isinstance(key, str): + raise TypeError("Only string keys allowed in flattening") + if not isinstance(value, dict): + new[key] = value + continue + for sub_key, sub_value in _flatten(value).items(): + new[key + "." + sub_key] = sub_value + return new + + return _flatten(self) diff --git a/src/gen_experiments/utils.py b/src/gen_experiments/utils.py index 49eddae..800b178 100644 --- a/src/gen_experiments/utils.py +++ b/src/gen_experiments/utils.py @@ -1,18 +1,5 @@ -from collections import defaultdict -from collections.abc import Iterable -from dataclasses import dataclass from itertools import chain -from types import EllipsisType as ellipsis -from typing import ( - Annotated, - Any, - Collection, - Optional, - Sequence, - TypedDict, - TypeVar, - cast, -) +from typing import Annotated, TypedDict, cast from warnings import warn import auto_ks as aks @@ -21,13 +8,9 @@ import pysindy as ps import sklearn import sklearn.metrics -from numpy.typing import DTypeLike, NBitBase, NDArray +from numpy.typing import NDArray -NpFlt = np.dtype[np.floating[NBitBase]] -Float1D = np.ndarray[tuple[int], NpFlt] -Float2D = np.ndarray[tuple[int, int], NpFlt] -Shape = TypeVar("Shape", bound=tuple[int, ...]) -FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]] +from .typing import Float1D, Float2D, FloatND class SINDyTrialData(TypedDict): @@ -56,36 +39,6 @@ class FullSINDyTrialData(SINDyTrialData): x_sim: np.ndarray -class 
SavedData(TypedDict): - params: dict - pind: tuple[int] - data: SINDyTrialData | FullSINDyTrialData - - -T = TypeVar("T", bound=np.generic) -GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] -SeriesData = Annotated[ - list[ - tuple[ - Annotated[GridsearchResult, "metrics"], - Annotated[GridsearchResult[np.void], "arg_opts"], - ] - ], - "len=n_grid_axes", -] - - -class GridsearchResultDetails(TypedDict): - system: str - plot_data: list[SavedData] - series_data: dict[str, SeriesData] - metrics: list[str] - grid_params: list[str] - grid_vals: list[Sequence] - grid_axes: dict[str, Collection[float]] - main: float - - def diff_lookup(kind): normalized_kind = kind.lower().replace(" ", "") if normalized_kind == "finitedifference": @@ -264,141 +217,6 @@ def simulate_test_data(model: ps.SINDy, dt: float, x_test: Float2D) -> SINDyTria return {"t_sim": t_sim, "x_sim": x_sim, "t_test": t_test} -@dataclass -class SeriesDef: - """The details of constructing the ragged axes of a grid search. - - The concept of a SeriesDef refers to a slice along a single axis of - a grid search in conjunction with another axis (or axes) - whose size or meaning differs along different slices. - - Attributes: - name: The name of the slice, as a label for printing - static_param: the constant parameter to this slice. Then key is - the name of the parameter, as understood by the experiment - Conceptually, the key serves as an index of this slice in - the gridsearch. 
- grid_params: the keys of the parameters in the experiment that - vary along jagged axis for this slice - grid_vals: the values of the parameters in the experiment that - vary along jagged axis for this slice - - Example: - - truck_wheels = SeriesDef( - "Truck", - {"vehicle": "flatbed_truck"}, - ["vehicle.n_wheels"], - [[10, 18]] - ) - - """ - - name: str - static_param: dict - grid_params: Optional[Sequence[str]] - grid_vals: Optional[list[Sequence]] - - -@dataclass -class SeriesList: - """Specify the ragged slices of a grid search. - - As an example, consider a grid search of miles per gallon for - different vehicles, in different routes, with different tires. - Since different tires fit on different vehicles, the tire axis would - be ragged, varying along the vehicle axis. - - Truck = SeriesDef("trucks") - - Attributes: - param_name: the key of the parameter in the experiment that - varies along the series axis. - print_name: the print name of the parameter in the experiment - that varies along the series axis. - series_list: Each element of the series axis - - Example: - - truck_wheels = SeriesDef( - "Truck", - {"vehicle": "flatbed_truck"}, - ["vehicle.n_wheels"], - [[10, 18]] - ) - bike_tires = SeriesDef( - "Bike", - {"vehicle": "bicycle"}, - ["vehicle.tires"], - [["gravel_tires", "road_tires"]] - ) - VehicleOptions = SeriesList( - "vehicle", - "Vehicle Types", - [truck_wheels, bike_tires] - ) - - """ - - param_name: Optional[str] - print_name: Optional[str] - series_list: list[SeriesDef] - - -class NestedDict(defaultdict): - """A dictionary that splits all keys by ".", creating a sub-dict. - - Args: see superclass - - Example: - - >>> foo = NestedDict("a.b"=1) - >>> foo["a.c"] = 2 - >>> foo["a"]["b"] - 1 - """ - - def __missing__(self, key): - try: - prefix, subkey = key.split(".", 1) - except ValueError: - raise KeyError(key) - return self[prefix][subkey] - - def __setitem__(self, key, value): - if "." 
in key: - prefix, suffix = key.split(".", 1) - if self.get(prefix) is None: - self[prefix] = NestedDict() - return self[prefix].__setitem__(suffix, value) - else: - return super().__setitem__(key, value) - - def update(self, other: dict): # type: ignore - try: - for k, v in other.items(): - self.__setitem__(k, v) - except: # noqa: E722 - super().update(other) - - def flatten(self): - """Flattens a nested dictionary without mutating. Returns new dict""" - - def _flatten(nested_d: dict) -> dict: - new = {} - for key, value in nested_d.items(): - if not isinstance(key, str): - raise TypeError("Only string keys allowed in flattening") - if not isinstance(value, dict): - new[key] = value - continue - for sub_key, sub_value in _flatten(value).items(): - new[key + "." + sub_key] = sub_value - return new - - return _flatten(self) - - def kalman_generalized_cv( times: np.ndarray, measurements: np.ndarray, alpha0: float = 1, detail=False ): @@ -435,192 +253,3 @@ def proj(curr_params, t): est_Q = np.linalg.inv(params.W_neg_sqrt @ params.W_neg_sqrt.T) est_alpha = 1 / (est_Q / Qi).mean() return est_alpha - - -def _amax_to_full_inds( - amax_inds: Collection[tuple[int | slice, int] | ellipsis] | ellipsis, - amax_arrays: list[list[GridsearchResult[np.void]]], -) -> set[tuple[int, ...]]: - """Find full indexers to selected elements of argmax arrays - - Args: - amax_inds: selection statemtent of which argmaxes to return. - amax_arrays: arrays of indexes to full gridsearch that are responsible for - the computed max values. First level of nesting reflects series(?), second - level reflects which grid grid axis. 
- Returns: - all indexers to full gridsearch that are requested by amax_inds - """ - - def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]: - return tuple(int(el) for el in cast(Iterable, tuple_like)) - - if amax_inds is ...: # grab each element from arrays in list of lists of arrays - return { - np_to_primitive(el) - for ar_list in amax_arrays - for arr in ar_list - for el in arr.flatten() - } - all_inds = set() - for plot_axis_results in [el for series in amax_arrays for el in series]: - for ind in amax_inds: - if ind is ...: # grab each element from arrays in list of lists of arrays - all_inds |= { - np_to_primitive(el) - for ar_list in amax_arrays - for arr in ar_list - for el in arr.flatten() - } - elif isinstance(ind[0], int): - all_inds |= {np_to_primitive(cast(np.void, plot_axis_results[ind]))} - else: # ind[0] is slice(None) - all_inds |= {np_to_primitive(el) for el in plot_axis_results[ind]} - return all_inds - - -def _argopt( - arr: FloatND, axis: Optional[int | tuple[int, ...]] = None, opt: str = "max" -) -> NDArray[np.void]: - """Calculate the argmax/min, but accept tuple axis. - - Ignores NaN values - - Args: - arr: an array to search - axis: The axis or axes to search through for the argmax/argmin. - opt: One of {"max", "min"} - - Returns: - array of indices for the argopt. 
If m = arr.ndim and n = len(axis), - the final result will be an array of ndim = m-n with elements being - tuples of length m - """ - dtype: DTypeLike = [(f"f{axind}", "i") for axind in range(arr.ndim)] - if axis is None: - axis = () - axis = (axis,) if isinstance(axis, int) else axis - keep_axes = tuple(sorted(set(range(arr.ndim)) - set(axis))) - keep_shape = tuple(arr.shape[ax] for ax in keep_axes) - result = np.empty(keep_shape, dtype=dtype) - optfun = np.nanargmax if opt == "max" else np.nanargmin - for slise in np.ndindex(keep_shape): - sub_arr = arr - # since we shrink shape, we need to chop of axes from the end - for ind, ax in zip(reversed(slise), reversed(keep_axes)): - sub_arr = np.take(sub_arr, ind, ax) - subind_max = np.unravel_index(optfun(sub_arr), sub_arr.shape) - fullind_max = np.empty((arr.ndim), int) - fullind_max[np.array(keep_axes, int)] = slise - fullind_max[np.array(axis, int)] = subind_max - result[slise] = tuple(fullind_max) - return result - - -def _index_in(base: tuple[int, ...], tgt: tuple[int | ellipsis | slice, ...]) -> bool: - """Determine whether base indexing tuple will match given numpy index""" - if len(base) > len(tgt): - return False - curr_ax = 0 - for ax, ind in enumerate(tgt): - if isinstance(ind, int): - try: - if ind != base[curr_ax]: - return False - except IndexError: - return False - elif isinstance(ind, slice): - if not (ind.start is None and ind.stop is None and ind.step is None): - raise ValueError("Only slices allowed are `slice(None)`") - elif ind is ...: - base_ind_remaining = len(base) - curr_ax - tgt_ind_remaining = len(tgt) - ax - # ellipsis can take 0 or more spots - curr_ax += max(base_ind_remaining - tgt_ind_remaining, -1) - curr_ax += 1 - if curr_ax == len(base): - return True - return False - - -def _grid_locator_match( - exp_params: dict[str, Any], - exp_ind: tuple[int, ...], - param_spec: Collection[dict[str, Any]], - ind_spec: Collection[tuple[int, ...]], -) -> bool: - """Determine whether experimental 
parameters match a specification - - Logical clause applied is: - - OR((exp_params MATCHES params for params in param_spec)) - AND - OR((exp_ind MATCHES ind for ind in ind_spec)) - - Treats OR of an empty collection as falsy - Args: - exp_params: the experiment parameters to evaluate - exp_ind: the experiemnt's full-size grid index to evaluate - param_spec: the criteria for matching exp_params - ind_spec: the criteria for matching exp_ind - """ - found_match = False - for params_or in param_spec: - params_or = {k: _param_normalize(v) for k, v in params_or.items()} - - try: - if all( - _param_normalize(exp_params[param]) == value - for param, value in params_or.items() - ): - found_match = True - break - except KeyError: - pass - for ind_or in ind_spec: - # exp_ind doesn't include metric, so skip first metric - if _index_in(exp_ind, ind_or[1:]): - break - else: - return False - return found_match - - -def strict_find_grid_match( - results: GridsearchResultDetails, - *, - params: Optional[dict[str, Any]] = None, - ind_spec: Optional[tuple[int | slice, int] | ellipsis] = None, -) -> SINDyTrialData: - if params is None: - params = {} - if ind_spec is None: - ind_spec = ... 
- matches = [] - amax_arrays = [ - [single_ser_and_axis[1] for single_ser_and_axis in single_series_all_axes] - for _, single_series_all_axes in results["series_data"].items() - ] - full_inds = _amax_to_full_inds((ind_spec,), amax_arrays) - - for trajectory in results["plot_data"]: - if _grid_locator_match( - trajectory["params"], trajectory["pind"], (params,), full_inds - ): - matches.append(trajectory) - - if len(matches) > 1: - raise ValueError("Specification is nonunique; matched multiple results") - if len(matches) == 0: - raise ValueError("Could not find a match") - return matches[0]["data"] - - -_EqTester = TypeVar("_EqTester") - - -def _param_normalize(val: _EqTester) -> _EqTester | str: - if type(val).__eq__ == object.__eq__: - return repr(val) - else: - return val diff --git a/tests/test_all.py b/tests/test_all.py index 962d0df..13274fd 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,187 +1,21 @@ -import numpy as np import pytest -from gen_experiments import gridsearch, utils - - -def test_thin_indexing(): - result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (-1,)))) - expected = { - (0, 0, 0), - (0, 1, 0), - (1, 0, 0), - (1, 1, 0), - (1, 0, 1), - (1, 1, 1), - } - assert result == expected - - -def test_thin_indexing_default(): - result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), None)) - expected = { - (0, 0, 0), - (0, 1, 0), - (1, 0, 0), - (1, 1, 0), - (0, 0, 1), - (0, 1, 1), - } - assert result == expected - - -def test_thin_indexing_callable(): - result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (lambda x: x,)))) - expected = { - (0, 0, 0), - (0, 1, 0), - (1, 0, 0), - (1, 1, 0), - (1, 0, 1), - (1, 1, 1), - } - assert result == expected - - -def test_curr_skinny_specs(): - grid_params = ["a", "c", "e", "f"] - skinny_specs = ( - ("a", "b", "c", "d", "e"), - ((1, 2, 3, 4), (0, 2, 3, 4), (0, 1, 3, 4), (0, 1, 2, 4), (0, 1, 2, 3)), - ) - ind_skinny, where_others = 
gridsearch._curr_skinny_specs(skinny_specs, grid_params) - assert ind_skinny == [0, 1, 2] - assert where_others == ((2, 4), (0, 4), (0, 2)) - - -def test_marginalize_grid_views(): - arr = np.arange(16, dtype=np.float_).reshape( - 2, 2, 2, 2 - ) # (metrics, param1, param2, param3) - arr[0, 0, 0, 0] = 1000 - arr[-1, -1, -1, 0] = -1000 - arr[0, 0, 0, 1] = np.nan - grid_decisions = ["plot", "max", "plot"] - opts = ["max", "min"] - res_val, res_ind = gridsearch._marginalize_grid_views(grid_decisions, arr, opts) - assert len(res_val) == len([dec for dec in grid_decisions if dec == "plot"]) - expected_val = [ - np.array([[1000, 7], [8, -1000]]), - np.array([[1000, 7], [-1000, 9]]), - ] - for result, expected in zip(res_val, expected_val): - np.testing.assert_array_equal(result, expected) - - ts = "i,i,i,i" - expected_ind = [ - np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 0, 0, 0), (1, 1, 1, 0)]], ts), - np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 1, 0), (1, 0, 0, 1)]], ts), - ] - for result, expected in zip(res_ind, expected_ind): - np.testing.assert_array_equal(result, expected) - - -def test_argopt_tuple_axis(): - arr = np.arange(16).reshape(2, 2, 2, 2) - arr[0, 0, 0, 0] = 1000 - result = utils._argopt(arr, (1, 3)) - expected = np.array( - [[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 0, 1), (1, 1, 1, 1)]], dtype="i,i,i,i" - ) - np.testing.assert_array_equal(result, expected) - - -def test_argopt_empty_tuple_axis(): - arr = np.arange(4).reshape(4) - result = utils._argopt(arr, ()) - expected = np.array([(0,), (1,), (2,), (3,)], dtype=[("f0", "i")]) - np.testing.assert_array_equal(result, expected) - result = utils._argopt(arr, None) - pass - - -def test_argopt_int_axis(): - arr = np.arange(8).reshape(2, 2, 2) - arr[0, 0, 0] = 1000 - result = utils._argopt(arr, 1) - expected = np.array([[(0, 0, 0), (0, 1, 1)], [(1, 1, 0), (1, 1, 1)]], dtype="i,i,i") - np.testing.assert_array_equal(result, expected) - - -def test_index_in(): - match_me = (1, ..., slice(None), 3) - good = 
[(1, 2, 1, 3), (1, 1, 3)] - for g in good: - assert utils._index_in(g, match_me) - bad = [(1, 3), (1, 1, 2), (1, 1, 1, 2)] - for b in bad: - assert not utils._index_in(b, match_me) - - -def test_index_in_errors(): - with pytest.raises(ValueError): - utils._index_in((1,), (slice(-1),)) +from gen_experiments.typing import NestedDict def test_flatten_nested_dict(): - deep = utils.NestedDict(a=utils.NestedDict(b=1)) + deep = NestedDict(a=NestedDict(b=1)) result = deep.flatten() assert deep != result expected = {"a.b": 1} assert result == expected -def test_grid_locator_match(): - m_params = {"sim_params.t_end": 10, "foo": 1} - m_ind = (0, 1) - # Effectively testing the clause: (x OR y OR ...) AND (a OR b OR ...) - # Note: OR() with no args is falsy - # also note first index is stripped ind_spec - good_specs = [ - (({"sim_params.t_end": 10},), ((1, 0, 1),)), - (({"sim_params.t_end": 10},), ((1, 0, 1), (1, 0, ...))), - (({"sim_params.t_end": 10}, {"foo": 1}), ((1, 0, 1),)), - (({"sim_params.t_end": 10}, {"bar: 1"}), ((1, 0, 1),)), - ( - ({"sim_params.t_end": 10},), - ( - (1, 0, 1), - ( - 1, - 1, - ), - ), - ), - ] - for param_spec, ind_spec in good_specs: - assert utils._grid_locator_match(m_params, m_ind, param_spec, ind_spec) - - bad_specs = [ - ((), ((0, 1),)), - (({"sim_params.t_end": 10},), ()), - (({"sim_params.t_end": 9},), ((1, 0, 1),)), - (({"sim_params.t_end": 10},), ((1, 0, 0),)), - ((), ()), - ] - for param_spec, ind_spec in bad_specs: - assert not utils._grid_locator_match(m_params, m_ind, param_spec, ind_spec) - - -def test_amax_to_full_inds(): - amax_inds = ((1, 1), (slice(None), 0)) - arr = np.array([[(0, 0), (0, 1)], [(1, 0), (1, 1)]], dtype="i,i") - amax_arrays = [[arr, arr], [arr]] - result = gridsearch._amax_to_full_inds(amax_inds, amax_arrays) - expected = {(0, 0), (1, 1), (1, 0)} - assert result == expected - result = gridsearch._amax_to_full_inds(..., amax_arrays) - expected |= {(0, 1)} - return result == expected - - def 
test_flatten_nested_bad_dict(): + nested = {1: NestedDict(b=1)} + # Testing the very thing that causes a typing error, thus ignoring with pytest.raises(TypeError, match="keywords must be strings"): - utils.NestedDict(**{1: utils.NestedDict(b=1)}) + NestedDict(**nested) # type: ignore with pytest.raises(TypeError, match="Only string keys allowed"): - deep = utils.NestedDict(a={1: 1}) + deep = NestedDict(a={1: 1}) deep.flatten() diff --git a/tests/test_gridsearch.py b/tests/test_gridsearch.py new file mode 100644 index 0000000..6d45ba9 --- /dev/null +++ b/tests/test_gridsearch.py @@ -0,0 +1,238 @@ +import numpy as np +import pysindy as ps +import pytest + +from gen_experiments import gridsearch +from gen_experiments.gridsearch.typing import ( + GridsearchResultDetails, + SavedGridPoint, + SeriesData, +) + + +def test_thin_indexing(): + result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (-1,)))) + expected = { + (0, 0, 0), + (0, 1, 0), + (1, 0, 0), + (1, 1, 0), + (1, 0, 1), + (1, 1, 1), + } + assert result == expected + + +def test_thin_indexing_default(): + result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), None)) + expected = { + (0, 0, 0), + (0, 1, 0), + (1, 0, 0), + (1, 1, 0), + (0, 0, 1), + (0, 1, 1), + } + assert result == expected + + +def test_thin_indexing_callable(): + result = set(gridsearch._ndindex_skinny((2, 2, 2), (0, 2), ((0,), (lambda x: x,)))) + expected = { + (0, 0, 0), + (0, 1, 0), + (1, 0, 0), + (1, 1, 0), + (1, 0, 1), + (1, 1, 1), + } + assert result == expected + + +def test_curr_skinny_specs(): + grid_params = ["a", "c", "e", "f"] + skinny_specs = ( + ("a", "b", "c", "d", "e"), + ((1, 2, 3, 4), (0, 2, 3, 4), (0, 1, 3, 4), (0, 1, 2, 4), (0, 1, 2, 3)), + ) + ind_skinny, where_others = gridsearch._curr_skinny_specs(skinny_specs, grid_params) + assert ind_skinny == [0, 1, 2] + assert where_others == ((2, 4), (0, 4), (0, 2)) + + +def test_marginalize_grid_views(): + arr = np.arange(16, dtype=np.float64).reshape( + 2, 2, 2, 2 +
) # (metrics, param1, param2, param3) + arr[0, 0, 0, 0] = 1000 + arr[-1, -1, -1, 0] = -1000 + arr[0, 0, 0, 1] = np.nan + grid_decisions = ["plot", "max", "plot"] + opts = ["max", "min"] + res_val, res_ind = gridsearch._marginalize_grid_views(grid_decisions, arr, opts) + assert len(res_val) == len([dec for dec in grid_decisions if dec == "plot"]) + expected_val = [ + np.array([[1000, 7], [8, -1000]]), + np.array([[1000, 7], [-1000, 9]]), + ] + for result, expected in zip(res_val, expected_val): + np.testing.assert_array_equal(result, expected) + + ts = "i,i,i,i" + expected_ind = [ + np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 0, 0, 0), (1, 1, 1, 0)]], ts), + np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 1, 0), (1, 0, 0, 1)]], ts), + ] + for result, expected in zip(res_ind, expected_ind): + np.testing.assert_array_equal(result, expected) + + +def test_argopt_tuple_axis(): + arr = np.arange(16, dtype=np.float64).reshape(2, 2, 2, 2) + arr[0, 0, 0, 0] = 1000 + result = gridsearch._argopt(arr, (1, 3)) + expected = np.array( + [[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 0, 1), (1, 1, 1, 1)]], dtype="i,i,i,i" + ) + np.testing.assert_array_equal(result, expected) + + +def test_argopt_empty_tuple_axis(): + arr = np.arange(4, dtype=np.float64).reshape(4) + result = gridsearch._argopt(arr, ()) + expected = np.array([(0,), (1,), (2,), (3,)], dtype=[("f0", "i")]) + np.testing.assert_array_equal(result, expected) + result = gridsearch._argopt(arr, None) + pass + + +def test_argopt_int_axis(): + arr = np.arange(8, dtype=np.float64).reshape(2, 2, 2) + arr[0, 0, 0] = 1000 + result = gridsearch._argopt(arr, 1) + expected = np.array([[(0, 0, 0), (0, 1, 1)], [(1, 1, 0), (1, 1, 1)]], dtype="i,i,i") + np.testing.assert_array_equal(result, expected) + + +def test_index_in(): + match_me = (1, ..., slice(None), 3) + good = [(1, 2, 1, 3), (1, 1, 3)] + for g in good: + assert gridsearch._index_in(g, match_me) + bad = [(1, 3), (1, 1, 2), (1, 1, 1, 2)] + for b in bad: + assert not
gridsearch._index_in(b, match_me) + + +def test_index_in_errors(): + with pytest.raises(ValueError): + gridsearch._index_in((1,), (slice(-1),)) + + +def test_amax_to_full_inds(): + amax_inds = ((1, 1), (slice(None), 0)) + arr = np.array([[(0, 0), (0, 1)], [(1, 0), (1, 1)]], dtype="i,i") + amax_arrays = [[arr, arr], [arr]] + result = gridsearch._amax_to_full_inds(amax_inds, amax_arrays) + expected = {(0, 0), (1, 1), (1, 0)} + assert result == expected + result = gridsearch._amax_to_full_inds(..., amax_arrays) + expected |= {(0, 1)} + assert result == expected + + +@pytest.fixture +def gridsearch_results(): + want: SavedGridPoint = { + "params": {"diff_params.alpha": 0.1, "opt_params": ps.STLSQ()}, + "pind": (1,), + "data": {}, + } + dont_want: SavedGridPoint = { + "params": {"diff_params.alpha": 0.2, "opt_params": ps.SSR()}, + "pind": (0,), + "data": {}, + } + tup_dtype = np.dtype([("f0", "i")]) + max_amax: SeriesData = [ + (np.ones((2, 2)), np.array([[(1,), (0,)], [(0,), (0,)]], dtype=tup_dtype)), + (np.ones((2, 2)), np.array([[(0,), (0,)], [(0,), (0,)]], dtype=tup_dtype)), + ] + full_details: GridsearchResultDetails = { + "system": "sho", + "plot_data": [want, dont_want], + "series_data": {"foo": max_amax}, + "metrics": ("mse", "mae"), + "grid_params": ["sim_params.t_end", "sim_params.noise"], + "plot_params": ["sim_params.t_end", "sim_params.noise"], + "grid_vals": [[1, 2], [5, 6]], + "main": 1, + } + return want, full_details + + +@pytest.mark.parametrize( + "locator", + ( + gridsearch.GridLocator( + ("mse",), (("sim_params.t_end",), ...), [{"diff_params.alpha": 0.1}] + ), + gridsearch.GridLocator( + ("mse",), (("sim_params.t_end",), ...), [{"opt_params": ps.STLSQ()}] + ), + gridsearch.GridLocator( + ..., (..., ...), [{"diff_params.alpha": lambda x: x < 0.2}] + ), + gridsearch.GridLocator(("mse",), {("sim_params.t_end", (0,))}, [{}]), + gridsearch.GridLocator( + ..., (..., ...), [{"diff_params.alpha": 0.1}, {"diff_params.alpha": 0.3}] + ), + ), + ids=("exact",
"object", "callable", "by_axis", "or"), +) +def test_find_gridpoints(gridsearch_results, locator): + want, full_details = gridsearch_results + results = gridsearch.find_gridpoints( + locator, full_details["plot_data"], full_details + ) + assert [want] == results + + +def test_grid_locator_match(): + m_params = {"sim_params.t_end": 10, "foo": 1} + m_ind = (0, 1) + # Effectively testing the clause: (x OR y OR ...) AND (a OR b OR ...) + # Note: OR() with no args is falsy, AND() with no args is truthy + # also note the first index is stripped from ind_spec + good_specs = [ + (({"sim_params.t_end": 10},), ((1, 0, 1),)), + (({"sim_params.t_end": 10},), ((1, 0, 1), (1, 0, ...))), + (({"sim_params.t_end": 10}, {"foo": 1}), ((1, 0, 1),)), + (({"sim_params.t_end": 10}, {"bar": 1}), ((1, 0, 1),)), + (({"sim_params.t_end": 10},), ((1, 0, 1), (1, 1))), + ] + for param_spec, ind_spec in good_specs: + assert gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) + + bad_specs = [ + ((), ((0, 1),)), + (({"sim_params.t_end": 10},), ()), + (({"sim_params.t_end": 9},), ((1, 0, 1),)), + (({"sim_params.t_end": 10},), ((1, 0, 0),)), + ((), ()), + ] + for param_spec, ind_spec in bad_specs: + assert not gridsearch._grid_locator_match(m_params, m_ind, param_spec, ind_spec) + + +def test_gridsearch_mock(): + results = gridsearch.run( + 1, + "none", + grid_params=["foo"], + grid_vals=[[0, 1]], + grid_decisions=["plot"], + other_params={"bar": False}, + metrics=("mse", "mae"), + ) + assert len(results["plot_data"]) == 0