Skip to content

Commit

Permalink
Add grid_size_ and move attributes initialization to fit
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 17, 2024
1 parent 1e945ba commit 4682e53
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Changelog

* Renamed :class:`metalearners.grid_search._GSResult` to :class:`metalearners.grid_search.GSResult`.

* Added ``grid_size_`` attribute to :class:`metalearners.grid_search.MetaLearnerGridSearch`.


0.7.0 (2024-07-12)
------------------
Expand Down
15 changes: 7 additions & 8 deletions metalearners/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,6 @@ def __init__(
self.store_raw_results = store_raw_results
self.store_results = store_results

self.raw_results_: list[GSResult] | Generator[GSResult, None, None] | None = (
None
)
self.results_: pd.DataFrame | None = None

all_base_models = set(
metalearner_factory.nuisance_model_specifications().keys()
) | set(metalearner_factory.treatment_model_specifications().keys())
Expand Down Expand Up @@ -348,14 +343,18 @@ def fit(
metalerner_fit_params=kwargs,
)
)

self.grid_size_ = len(jobs)
self.raw_results_: list[GSResult] | Generator[GSResult, None, None] | None
self.results_: pd.DataFrame | None

return_as = "list" if self.store_raw_results else "generator_unordered"
parallel = Parallel(
n_jobs=self.n_jobs, verbose=self.verbose, return_as=return_as
)
raw_results = parallel(delayed(_fit_and_score)(job) for job in jobs)
self.raw_results_ = raw_results
self.raw_results = parallel(delayed(_fit_and_score)(job) for job in jobs)
if self.store_results:
self.results_ = _format_results(results=raw_results)
self.results_ = _format_results(results=self.raw_results)
if not self.store_raw_results:
# This just checks that the generator is empty
try:
Expand Down
1 change: 1 addition & 0 deletions tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def test_metalearnergridsearch_smoke(
assert gs.results_ is not None
assert gs.results_.shape[0] == expected_n_configs
assert gs.results_.index.names == expected_index_cols
assert gs.grid_size_ == expected_n_configs

train_scores_cols = set(
c[6:] for c in list(gs.results_.columns) if c.startswith("train_")
Expand Down

0 comments on commit 4682e53

Please sign in to comment.