Skip to content

Commit

Permalink
fix: Make track sim_params correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Oct 1, 2024
1 parent 1d12252 commit e3e378a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/gen_experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class NoExperiment:
name = "No Experiment"
lookup_dict = {"arg": {"foo": 1}}

@staticmethod
def gen_data(*args: Any, **kwargs: Any) -> dict[str, Any]:
return {}

@staticmethod
def run(
*args: Any, return_all: bool = True, **kwargs: Any
Expand Down
6 changes: 5 additions & 1 deletion src/gen_experiments/gridsearch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def run(
elif base_ex.__name__ == "gen_experiments.pdes":
plot_panel = plot_pde_panel
data_step = gen_pde_data
elif base_ex.__name__ == "NoExperiment":
data_step = gen_experiments.NoExperiment.gen_data
if series_params is None:
series_params = SeriesList(None, None, [SeriesDef(group, {}, [], [])])
legends = False
Expand Down Expand Up @@ -220,10 +222,12 @@ def run(
start = process_time()
for axis_ind, key, val_list in zip(ind, new_grid_params, new_grid_vals):
curr_other_params[key] = val_list[axis_ind]
data = data_step(seed=seed, **curr_other_params.pop("sim_params"))
sim_params = curr_other_params.pop("sim_params", {})
data = data_step(seed=seed, **sim_params)
curr_results, grid_data = base_ex.run(
data, **curr_other_params, display=False, return_all=True
)
curr_results["sim_params"] = sim_params
intermediate_data.append(
{"params": curr_other_params.flatten(), "pind": ind, "data": grid_data}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_gridsearch_mock():
grid_params=["foo"],
grid_vals=[[0, 1]],
grid_decisions=["plot"],
other_params={"bar": False},
other_params={"bar": False, "sim_params": {}},
metrics=("mse", "mae"),
)
assert len(results["plot_data"]) == 0

0 comments on commit e3e378a

Please sign in to comment.