Skip to content

Commit

Permalink
Reduce memory usage by not creating metalearner object
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 17, 2024
1 parent 44fa6ec commit 629999d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 18 deletions.
49 changes: 31 additions & 18 deletions metalearners/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

@dataclass(frozen=True)
class _FitAndScoreJob:
metalearner: MetaLearner
metalearner_factory: type[MetaLearner]
metalearner_params: dict[str, Any]
X_train: Matrix
y_train: Vector
w_train: Vector
Expand All @@ -44,21 +45,20 @@ class GSResult:

def _fit_and_score(job: _FitAndScoreJob) -> GSResult:
start_time = time.time()
job.metalearner.fit(
job.X_train, job.y_train, job.w_train, **job.metalerner_fit_params
)
ml = job.metalearner_factory(**job.metalearner_params)
ml.fit(job.X_train, job.y_train, job.w_train, **job.metalerner_fit_params)
fit_time = time.time() - start_time
start_time = time.time()

train_scores = job.metalearner.evaluate(
train_scores = ml.evaluate(
X=job.X_train,
y=job.y_train,
w=job.w_train,
is_oos=False,
scoring=job.scoring,
)
if job.X_test is not None and job.y_test is not None and job.w_test is not None:
test_scores = job.metalearner.evaluate(
test_scores = ml.evaluate(
X=job.X_test,
y=job.y_test,
w=job.w_test,
Expand All @@ -70,7 +70,7 @@ def _fit_and_score(job: _FitAndScoreJob) -> GSResult:
test_scores = None
score_time = time.time() - start_time
return GSResult(
metalearner=job.metalearner,
metalearner=ml,
fit_time=fit_time,
score_time=score_time,
train_scores=train_scores,
Expand Down Expand Up @@ -310,20 +310,33 @@ def fit(
}
propensity_model_params = params.get(PROPENSITY_MODEL, None)

ml = self.metalearner_factory(
**self.metalearner_params,
nuisance_model_factory=nuisance_model_factory,
treatment_model_factory=treatment_model_factory,
propensity_model_factory=propensity_model_factory,
nuisance_model_params=nuisance_model_params,
treatment_model_params=treatment_model_params,
propensity_model_params=propensity_model_params,
random_state=self.random_state,
)
grid_metalearner_params = {
"nuisance_model_factory": nuisance_model_factory,
"treatment_model_factory": treatment_model_factory,
"propensity_model_factory": propensity_model_factory,
"nuisance_model_params": nuisance_model_params,
"treatment_model_params": treatment_model_params,
"propensity_model_params": propensity_model_params,
"random_state": self.random_state,
}

if (
len(
shared_keys := set(grid_metalearner_params.keys())
& set(self.metalearner_params.keys())
)
> 0
):
raise ValueError(
f"{shared_keys} should not be specified in metalearner_params as "
"they are used internally. Please use the correct parameters."
)

jobs.append(
_FitAndScoreJob(
metalearner=ml,
metalearner_factory=self.metalearner_factory,
metalearner_params=dict(self.metalearner_params)
| grid_metalearner_params,
X_train=X,
y_train=y,
w_train=w,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,26 @@ def test_metalearnergridsearch_store(
gs.fit(X, y, w, X_test, y_test, w_test)
assert isinstance(gs.raw_results_, expected_type_raw_results)
assert isinstance(gs.results_, expected_type_results)


def test_metalearnergridsearch_error(grid_search_data):
    """Check that ``fit`` rejects reserved keys passed via ``metalearner_params``.

    The ``metalearner_params`` built here include ``random_state``. The test
    expects ``MetaLearnerGridSearch.fit`` to raise a ``ValueError`` whose
    message contains "should not be specified in metalearner_params".
    """
    X, _, y, w, X_test, _, y_test, w_test = grid_search_data

    grid_search = MetaLearnerGridSearch(
        metalearner_factory=SLearner,
        metalearner_params={
            "is_classification": False,
            "n_variants": len(np.unique(w)),
            "n_folds": 2,
            # Presumably the entry that triggers the error, since the grid
            # search supplies its own random_state -- confirm against fit().
            "random_state": 1,
        },
        base_learner_grid={"base_model": [LinearRegression, LGBMRegressor]},
        param_grid={"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}},
    )

    expected_message = "should not be specified in metalearner_params"
    with pytest.raises(ValueError, match=expected_message):
        grid_search.fit(X, y, w, X_test, y_test, w_test)

0 comments on commit 629999d

Please sign in to comment.