Skip to content

Commit

Permalink
fix(gridsearch): link finding gridpoints to series
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Apr 23, 2024
1 parent 64d7e10 commit 139470e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ gen_experiments/trials/
scratch/
*.png
debugme*.py
trials/*

# IDE files
.vscode
Expand Down
51 changes: 28 additions & 23 deletions src/gen_experiments/gridsearch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from logging import getLogger
from pprint import pformat
from types import EllipsisType as ellipsis
from typing import (
Annotated,
Any,
Callable,
Collection,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
from typing import Annotated, Any, Callable, Optional, Sequence, TypeVar, Union, cast
from warnings import warn

import matplotlib.pyplot as plt
Expand All @@ -24,12 +14,12 @@
from scipy.stats import kstest

import gen_experiments

from .. import config
from ..odes import plot_ode_panel
from ..plotting import _PlotPrefs
from ..typing import FloatND, NestedDict
from ..utils import simulate_test_data

from .typing import (
ExpResult,
GridLocator,
Expand Down Expand Up @@ -178,6 +168,7 @@ def run(
docstring for _ndindex_skinny). By default, all plot axes
are made skinny with respect to each other.
"""
logger.info(f"Beginning gridsearch of system: {group}")
other_params = NestedDict(**other_params)
base_ex, base_group = gen_experiments.experiments[group]
if series_params is None:
Expand Down Expand Up @@ -213,7 +204,10 @@ def run(
full_results_shape[1:], ind_skinny, where_others
)
for ind_counter, ind in enumerate(gridpoint_selector):
print(f"Calculating series {s_counter}, gridpoint{ind_counter}", end="\r")
logger.info(
f"Calculating series {s_counter} ({series_data.name}), "
f"gridpoint {ind_counter} ({ind})"
)
for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals):
curr_other_params[key] = val_list[axis_ind]
curr_results, grid_data = base_ex.run(
Expand Down Expand Up @@ -256,13 +250,26 @@ def run(
),
}
if plot_prefs:
plot_data = find_gridpoints(
plot_prefs.plot_match,
intermediate_data,
results["series_data"].values(),
results["metrics"],
results["scan_grid"]
)
plot_data = []
# todo - improve how plot_prefs.plot_match interacts with series
# This is a horrible hack, assuming a params_or for each series, ino
for series_data, params in zip(
series_params.series_list,
list(plot_prefs.plot_match.params_or),
strict=True,
):
key = series_data.name
logger.info(f"Searching for matching points in series: {key}")
locator = GridLocator(
plot_prefs.plot_match.metrics, plot_prefs.plot_match.keep_axes, [params]
)
plot_data += find_gridpoints(
locator,
intermediate_data,
[results["series_data"][key]],
results["metrics"],
results["scan_grid"],
)
results["plot_data"] = plot_data
for gridpoint in plot_data:
grid_data = gridpoint["data"]
Expand Down Expand Up @@ -599,9 +606,7 @@ def find_gridpoints(
if find.metrics is ...:
inds_of_metrics = range(len(argopt_metrics))
else:
inds_of_metrics = tuple(
argopt_metrics.index(metric) for metric in find.metrics
)
inds_of_metrics = tuple(argopt_metrics.index(metric) for metric in find.metrics)
# No deduplication is done!
keep_axes = _normalize_keep_axes(find.keep_axes, argopt_axes)

Expand Down

0 comments on commit 139470e

Please sign in to comment.