Skip to content

Commit

Permalink
Index dataframe with config
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 4, 2024
1 parent 5a6c91f commit 991b2f1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
8 changes: 8 additions & 0 deletions metalearners/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame:
row[f"test_{name}"] = value
rows.append(row)
df = pd.DataFrame(rows)
index_columns = [
c
for c in df.columns
if not c.endswith("_time")
and not c.startswith("train_")
and not c.startswith("test_")
]
df = df.set_index(index_columns)
return df


Expand Down
16 changes: 14 additions & 2 deletions tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,31 @@


@pytest.mark.parametrize(
"metalearner_factory, is_classification, base_learner_grid, param_grid, expected_n_configs",
"metalearner_factory, is_classification, base_learner_grid, param_grid, expected_n_configs, expected_index_cols",
[
(
SLearner,
False,
{"base_model": [LinearRegression, LGBMRegressor]},
{"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}},
3,
3,
),
(
SLearner,
True,
{"base_model": [LogisticRegression, LGBMClassifier]},
{"base_model": {"LGBMClassifier": {"n_estimators": [1, 2]}}},
3,
3,
),
(
TLearner,
False,
{"variant_outcome_model": [LinearRegression, LGBMRegressor]},
{"variant_outcome_model": {"LGBMRegressor": {"n_estimators": [1, 2, 3]}}},
4,
3,
),
(
XLearner,
Expand All @@ -55,6 +58,7 @@
"treatment_effect_model": {"LGBMRegressor": {"n_estimators": [1]}},
},
6,
8,
),
(
RLearner,
Expand All @@ -66,9 +70,12 @@
},
{
"propensity_model": {"LGBMClassifier": {"n_estimators": [1, 2, 3]}},
"treatment_model": {"LGBMRegressor": {"n_estimators": [1, 2, 3]}},
"treatment_model": {
"LGBMRegressor": {"n_estimators": [1, 2, 3], "learning_rate": [0.4]}
},
},
9,
7,
),
(
DRLearner,
Expand All @@ -82,6 +89,7 @@
"propensity_model": {"LGBMClassifier": {"n_estimators": [1, 2, 3, 4]}},
},
4,
5,
),
],
)
Expand All @@ -91,6 +99,7 @@ def test_metalearnergridsearch_smoke(
base_learner_grid,
param_grid,
expected_n_configs,
expected_index_cols,
grid_search_data,
):
X, y_class, y_reg, w, X_test, y_test_class, y_test_reg, w_test = grid_search_data
Expand All @@ -116,6 +125,7 @@ def test_metalearnergridsearch_smoke(
gs.fit(X, y, w, X_test, y_test, w_test)
assert gs.results_ is not None
assert gs.results_.shape[0] == expected_n_configs
assert len(gs.results_.index.names) == expected_index_cols

train_scores_cols = set(
c[6:] for c in list(gs.results_.columns) if c.startswith("train_")
Expand Down Expand Up @@ -175,6 +185,7 @@ def test_metalearnergridsearch_reuse_nuisance_smoke(grid_search_data):
}
assert gs.results_ is not None
assert gs.results_.shape[0] == 8
assert len(gs.results_.index.names) == 7


def test_metalearnergridsearch_reuse_propensity_smoke(grid_search_data):
Expand Down Expand Up @@ -220,3 +231,4 @@ def test_metalearnergridsearch_reuse_propensity_smoke(grid_search_data):
assert raw_result.metalearner._prefitted_nuisance_models == {PROPENSITY_MODEL}
assert gs.results_ is not None
assert gs.results_.shape[0] == 2
assert len(gs.results_.index.names) == 5

0 comments on commit 991b2f1

Please sign in to comment.