From 991b2f1696d0216339dffe12bb481d1ffce5d609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Thu, 4 Jul 2024 13:21:20 +0200 Subject: [PATCH] Index dataframe with config --- metalearners/grid_search.py | 8 ++++++++ tests/test_grid_search.py | 16 ++++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/metalearners/grid_search.py b/metalearners/grid_search.py index efecfbb..60e2dad 100644 --- a/metalearners/grid_search.py +++ b/metalearners/grid_search.py @@ -110,6 +110,14 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame: row[f"test_{name}"] = value rows.append(row) df = pd.DataFrame(rows) + index_columns = [ + c + for c in df.columns + if not c.endswith("_time") + and not c.startswith("train_") + and not c.startswith("test_") + ] + df = df.set_index(index_columns) return df diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py index 93ffff0..fd953ff 100644 --- a/tests/test_grid_search.py +++ b/tests/test_grid_search.py @@ -17,7 +17,7 @@ @pytest.mark.parametrize( - "metalearner_factory, is_classification, base_learner_grid, param_grid, expected_n_configs", + "metalearner_factory, is_classification, base_learner_grid, param_grid, expected_n_configs, expected_index_cols", [ ( SLearner, @@ -25,6 +25,7 @@ {"base_model": [LinearRegression, LGBMRegressor]}, {"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}}, 3, + 3, ), ( SLearner, @@ -32,6 +33,7 @@ {"base_model": [LogisticRegression, LGBMClassifier]}, {"base_model": {"LGBMClassifier": {"n_estimators": [1, 2]}}}, 3, + 3, ), ( TLearner, @@ -39,6 +41,7 @@ {"variant_outcome_model": [LinearRegression, LGBMRegressor]}, {"variant_outcome_model": {"LGBMRegressor": {"n_estimators": [1, 2, 3]}}}, 4, + 3, ), ( XLearner, @@ -55,6 +58,7 @@ "treatment_effect_model": {"LGBMRegressor": {"n_estimators": [1]}}, }, 6, + 8, ), ( RLearner, @@ -66,9 +70,12 @@ }, { "propensity_model": {"LGBMClassifier": {"n_estimators": [1, 2, 3]}}, - "treatment_model": {"LGBMRegressor": {"n_estimators": [1, 2, 3]}}, + "treatment_model": { + "LGBMRegressor": {"n_estimators": [1, 2, 3], "learning_rate": [0.4]} + }, }, 9, + 7, ), ( DRLearner, @@ -82,6 +89,7 @@ "propensity_model": {"LGBMClassifier": {"n_estimators": [1, 2, 3, 4]}}, }, 4, + 5, ), ], ) @@ -91,6 +99,7 @@ def test_metalearnergridsearch_smoke( base_learner_grid, param_grid, expected_n_configs, + expected_index_cols, grid_search_data, ): X, y_class, y_reg, w, X_test, y_test_class, y_test_reg, w_test = grid_search_data @@ -116,6 +125,7 @@ def test_metalearnergridsearch_smoke( gs.fit(X, y, w, X_test, y_test, w_test) assert gs.results_ is not None assert gs.results_.shape[0] == expected_n_configs + assert len(gs.results_.index.names) == expected_index_cols train_scores_cols = set( c[6:] for c in list(gs.results_.columns) if c.startswith("train_") @@ -175,6 +185,7 @@ def test_metalearnergridsearch_reuse_nuisance_smoke(grid_search_data): } assert gs.results_ is not None assert gs.results_.shape[0] == 8 + assert len(gs.results_.index.names) == 7 def test_metalearnergridsearch_reuse_propensity_smoke(grid_search_data): @@ -220,3 +231,4 @@ def test_metalearnergridsearch_reuse_propensity_smoke(grid_search_data): assert raw_result.metalearner._prefitted_nuisance_models == {PROPENSITY_MODEL} assert gs.results_ is not None assert gs.results_.shape[0] == 2 + assert len(gs.results_.index.names) == 5