Skip to content

Commit

Permalink
Merge branch 'main' into fix_lime_example
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC authored Jul 10, 2024
2 parents 67a3480 + 1580c96 commit bbbfe28
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 17 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
Changelog
=========

0.6.1 (2024-07-xx)
------------------

**Other changes**

* Changed the index columns order in ``MetaLearnerGridSearch.results_``.

0.6.0 (2024-07-08)
------------------

Expand Down
23 changes: 13 additions & 10 deletions metalearners/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame:
for result in results:
row: dict[str, str | int | float] = {}
row["metalearner"] = result.metalearner.__class__.__name__
nuisance_models = (
nuisance_models = sorted(
set(result.metalearner.nuisance_model_specifications().keys())
- result.metalearner._prefitted_nuisance_models
)
treatment_models = set(
result.metalearner.treatment_model_specifications().keys()
treatment_models = sorted(
set(result.metalearner.treatment_model_specifications().keys())
)
for model_kind in nuisance_models:
row[model_kind] = result.metalearner.nuisance_model_factory[
Expand All @@ -115,13 +115,16 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame:
row[f"test_{name}"] = value
rows.append(row)
df = pd.DataFrame(rows)
index_columns = [
c
for c in df.columns
if not c.endswith("_time")
and not c.startswith("train_")
and not c.startswith("test_")
]
sorted_cols = sorted(df.columns)
index_columns = ["metalearner"]
for model_kind in nuisance_models:
for c in sorted_cols:
if c.startswith(model_kind):
index_columns.append(c)
for model_kind in treatment_models:
for c in sorted_cols:
if c.startswith(model_kind):
index_columns.append(c)
df = df.set_index(index_columns)
return df

Expand Down
41 changes: 34 additions & 7 deletions tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,27 @@
{"base_model": [LinearRegression, LGBMRegressor]},
{"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}},
3,
3,
["metalearner", "base_model", "base_model_n_estimators"],
),
(
SLearner,
True,
{"base_model": [LogisticRegression, LGBMClassifier]},
{"base_model": {"LGBMClassifier": {"n_estimators": [1, 2]}}},
3,
3,
["metalearner", "base_model", "base_model_n_estimators"],
),
(
TLearner,
False,
{"variant_outcome_model": [LinearRegression, LGBMRegressor]},
{"variant_outcome_model": {"LGBMRegressor": {"n_estimators": [1, 2, 3]}}},
4,
3,
[
"metalearner",
"variant_outcome_model",
"variant_outcome_model_n_estimators",
],
),
(
XLearner,
Expand All @@ -58,7 +62,16 @@
"treatment_effect_model": {"LGBMRegressor": {"n_estimators": [1]}},
},
6,
8,
[
"metalearner",
"propensity_model",
"propensity_model_n_estimators",
"variant_outcome_model",
"control_effect_model",
"control_effect_model_n_estimators",
"treatment_effect_model",
"treatment_effect_model_n_estimators",
],
),
(
RLearner,
Expand All @@ -75,7 +88,15 @@
},
},
9,
7,
[
"metalearner",
"outcome_model",
"propensity_model",
"propensity_model_n_estimators",
"treatment_model",
"treatment_model_learning_rate",
"treatment_model_n_estimators",
],
),
(
DRLearner,
Expand All @@ -89,7 +110,13 @@
"propensity_model": {"LGBMClassifier": {"n_estimators": [1, 2, 3, 4]}},
},
4,
5,
[
"metalearner",
"propensity_model",
"propensity_model_n_estimators",
"variant_outcome_model",
"treatment_model",
],
),
],
)
Expand Down Expand Up @@ -125,7 +152,7 @@ def test_metalearnergridsearch_smoke(
gs.fit(X, y, w, X_test, y_test, w_test)
assert gs.results_ is not None
assert gs.results_.shape[0] == expected_n_configs
assert len(gs.results_.index.names) == expected_index_cols
assert gs.results_.index.names == expected_index_cols

train_scores_cols = set(
c[6:] for c in list(gs.results_.columns) if c.startswith("train_")
Expand Down

0 comments on commit bbbfe28

Please sign in to comment.