From 139470ee4b1374a06dd0a56f2d631d0ce15c606f Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 23 Apr 2024 15:59:47 -0700 Subject: [PATCH] fix(gridsearch): link finding gridpoints to series --- .gitignore | 1 + src/gen_experiments/gridsearch/__init__.py | 51 ++++++++++++---------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index a867520..6ff01f9 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ gen_experiments/trials/ scratch/ *.png debugme*.py +trials/* # IDE files .vscode diff --git a/src/gen_experiments/gridsearch/__init__.py b/src/gen_experiments/gridsearch/__init__.py index b65901b..d66a324 100644 --- a/src/gen_experiments/gridsearch/__init__.py +++ b/src/gen_experiments/gridsearch/__init__.py @@ -4,17 +4,7 @@ from logging import getLogger from pprint import pformat from types import EllipsisType as ellipsis -from typing import ( - Annotated, - Any, - Callable, - Collection, - Optional, - Sequence, - TypeVar, - Union, - cast, -) +from typing import Annotated, Any, Callable, Optional, Sequence, TypeVar, Union, cast from warnings import warn import matplotlib.pyplot as plt @@ -24,12 +14,12 @@ from scipy.stats import kstest import gen_experiments + from .. import config from ..odes import plot_ode_panel from ..plotting import _PlotPrefs from ..typing import FloatND, NestedDict from ..utils import simulate_test_data - from .typing import ( ExpResult, GridLocator, @@ -178,6 +168,7 @@ def run( docstring for _ndindex_skinny). By default, all plot axes are made skinny with respect to each other. """ + logger.info(f"Beginning gridsearch of system: {group}") other_params = NestedDict(**other_params) base_ex, base_group = gen_experiments.experiments[group] if series_params is None: @@ -213,7 +204,10 @@ def run( full_results_shape[1:], ind_skinny, where_others ) for ind_counter, ind in enumerate(gridpoint_selector): - print(f"Calculating series {s_counter}, gridpoint{ind_counter}", end="\r") + logger.info( + f"Calculating series {s_counter} ({series_data.name}), " + f"gridpoint {ind_counter} ({ind})" + ) for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals): curr_other_params[key] = val_list[axis_ind] curr_results, grid_data = base_ex.run( @@ -256,13 +250,26 @@ def run( ), } if plot_prefs: - plot_data = find_gridpoints( - plot_prefs.plot_match, - intermediate_data, - results["series_data"].values(), - results["metrics"], - results["scan_grid"] - ) + plot_data = [] + # todo - improve how plot_prefs.plot_match interacts with series + # This is a horrible hack, assuming a params_or for each series, ino + for series_data, params in zip( + series_params.series_list, + list(plot_prefs.plot_match.params_or), + strict=True, + ): + key = series_data.name + logger.info(f"Searching for matching points in series: {key}") + locator = GridLocator( + plot_prefs.plot_match.metrics, plot_prefs.plot_match.keep_axes, [params] + ) + plot_data += find_gridpoints( + locator, + intermediate_data, + [results["series_data"][key]], + results["metrics"], + results["scan_grid"], + ) results["plot_data"] = plot_data for gridpoint in plot_data: grid_data = gridpoint["data"] @@ -599,9 +606,7 @@ def find_gridpoints( if find.metrics is ...: inds_of_metrics = range(len(argopt_metrics)) else: - inds_of_metrics = tuple( - argopt_metrics.index(metric) for metric in find.metrics - ) + inds_of_metrics = tuple(argopt_metrics.index(metric) for metric in find.metrics) # No deduplication is done! keep_axes = _normalize_keep_axes(find.keep_axes, argopt_axes)