Add a ``grid_size_`` attribute to ``MetaLearnerGridSearch``.

``fit`` now records the number of grid-search jobs in ``grid_size_``.
``raw_results_`` and ``results_`` are no longer pre-initialised to ``None``
in ``__init__``; they are declared in ``fit`` instead. The raw results are
stored on ``raw_results_`` (with the trailing underscore), matching the
attribute declared in ``fit``.

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 9095266..13a0dab 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -16,6 +16,8 @@ Changelog
 
 * Renamed :class:`metalearners.grid_search._GSResult` to :class:`metalearners.grid_search.GSResult`.
 
+* Added ``grid_size_`` attribute to :class:`metalearners.grid_search.MetaLearnerGridSearch`.
+
 0.7.0 (2024-07-12)
 ------------------
 
diff --git a/metalearners/grid_search.py b/metalearners/grid_search.py
index 7dfa2ea..304beb5 100644
--- a/metalearners/grid_search.py
+++ b/metalearners/grid_search.py
@@ -225,11 +225,6 @@ def __init__(
         self.store_raw_results = store_raw_results
         self.store_results = store_results
 
-        self.raw_results_: list[GSResult] | Generator[GSResult, None, None] | None = (
-            None
-        )
-        self.results_: pd.DataFrame | None = None
-
         all_base_models = set(
             metalearner_factory.nuisance_model_specifications().keys()
         ) | set(metalearner_factory.treatment_model_specifications().keys())
@@ -348,14 +343,18 @@ def fit(
                         metalerner_fit_params=kwargs,
                     )
                 )
+
+        self.grid_size_ = len(jobs)
+        self.raw_results_: list[GSResult] | Generator[GSResult, None, None] | None
+        self.results_: pd.DataFrame | None
+
         return_as = "list" if self.store_raw_results else "generator_unordered"
         parallel = Parallel(
             n_jobs=self.n_jobs, verbose=self.verbose, return_as=return_as
         )
-        raw_results = parallel(delayed(_fit_and_score)(job) for job in jobs)
-        self.raw_results_ = raw_results
+        self.raw_results_ = parallel(delayed(_fit_and_score)(job) for job in jobs)
         if self.store_results:
-            self.results_ = _format_results(results=raw_results)
+            self.results_ = _format_results(results=self.raw_results_)
             if not self.store_raw_results:
                 # This just checks that the generator is empty
                 try:
diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py
index 5d5e4d8..23e4cbd 100644
--- a/tests/test_grid_search.py
+++ b/tests/test_grid_search.py
@@ -156,6 +156,7 @@ def test_metalearnergridsearch_smoke(
     assert gs.results_ is not None
     assert gs.results_.shape[0] == expected_n_configs
     assert gs.results_.index.names == expected_index_cols
+    assert gs.grid_size_ == expected_n_configs
 
     train_scores_cols = set(
         c[6:] for c in list(gs.results_.columns) if c.startswith("train_")