From 629999dbbfd2f846a3e92b9b56be63ad784b906e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Wed, 17 Jul 2024 15:40:09 +0200 Subject: [PATCH] Reduce memory usage by not creating metalearner object --- metalearners/grid_search.py | 49 +++++++++++++++++++++++-------------- tests/test_grid_search.py | 23 +++++++++++++++++ 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/metalearners/grid_search.py b/metalearners/grid_search.py index 9879092..505cb4e 100644 --- a/metalearners/grid_search.py +++ b/metalearners/grid_search.py @@ -17,7 +17,8 @@ @dataclass(frozen=True) class _FitAndScoreJob: - metalearner: MetaLearner + metalearner_factory: type[MetaLearner] + metalearner_params: dict[str, Any] X_train: Matrix y_train: Vector w_train: Vector @@ -44,13 +45,12 @@ class GSResult: def _fit_and_score(job: _FitAndScoreJob) -> GSResult: start_time = time.time() - job.metalearner.fit( - job.X_train, job.y_train, job.w_train, **job.metalerner_fit_params - ) + ml = job.metalearner_factory(**job.metalearner_params) + ml.fit(job.X_train, job.y_train, job.w_train, **job.metalerner_fit_params) fit_time = time.time() - start_time start_time = time.time() - train_scores = job.metalearner.evaluate( + train_scores = ml.evaluate( X=job.X_train, y=job.y_train, w=job.w_train, @@ -58,7 +58,7 @@ def _fit_and_score(job: _FitAndScoreJob) -> GSResult: scoring=job.scoring, ) if job.X_test is not None and job.y_test is not None and job.w_test is not None: - test_scores = job.metalearner.evaluate( + test_scores = ml.evaluate( X=job.X_test, y=job.y_test, w=job.w_test, @@ -70,7 +70,7 @@ def _fit_and_score(job: _FitAndScoreJob) -> GSResult: test_scores = None score_time = time.time() - start_time return GSResult( - metalearner=job.metalearner, + metalearner=ml, fit_time=fit_time, score_time=score_time, train_scores=train_scores, @@ -310,20 +310,33 @@ def fit( } propensity_model_params = params.get(PROPENSITY_MODEL, None) - ml = self.metalearner_factory( - **self.metalearner_params, - nuisance_model_factory=nuisance_model_factory, - treatment_model_factory=treatment_model_factory, - propensity_model_factory=propensity_model_factory, - nuisance_model_params=nuisance_model_params, - treatment_model_params=treatment_model_params, - propensity_model_params=propensity_model_params, - random_state=self.random_state, - ) + grid_metalearner_params = { + "nuisance_model_factory": nuisance_model_factory, + "treatment_model_factory": treatment_model_factory, + "propensity_model_factory": propensity_model_factory, + "nuisance_model_params": nuisance_model_params, + "treatment_model_params": treatment_model_params, + "propensity_model_params": propensity_model_params, + "random_state": self.random_state, + } + + if ( + len( + shared_keys := set(grid_metalearner_params.keys()) + & set(self.metalearner_params.keys()) + ) + > 0 + ): + raise ValueError( + f"{shared_keys} should not be specified in metalearner_params as " + "they are used internally. Please use the correct parameters." + ) jobs.append( _FitAndScoreJob( - metalearner=ml, + metalearner_factory=self.metalearner_factory, + metalearner_params=dict(self.metalearner_params) + | grid_metalearner_params, X_train=X, y_train=y, w_train=w, diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py index 5949b86..5d5e4d8 100644 --- a/tests/test_grid_search.py +++ b/tests/test_grid_search.py @@ -301,3 +301,26 @@ def test_metalearnergridsearch_store( gs.fit(X, y, w, X_test, y_test, w_test) assert isinstance(gs.raw_results_, expected_type_raw_results) assert isinstance(gs.results_, expected_type_results) + + +def test_metalearnergridsearch_error(grid_search_data): + X, _, y, w, X_test, _, y_test, w_test = grid_search_data + n_variants = len(np.unique(w)) + + metalearner_params = { + "is_classification": False, + "n_variants": n_variants, + "n_folds": 2, + "random_state": 1, + } + + gs = MetaLearnerGridSearch( + metalearner_factory=SLearner, + metalearner_params=metalearner_params, + base_learner_grid={"base_model": [LinearRegression, LGBMRegressor]}, + param_grid={"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}}, + ) + with pytest.raises( + ValueError, match="should not be specified in metalearner_params" + ): + gs.fit(X, y, w, X_test, y_test, w_test)