From 38e84fdfcfb6be4f3b8bffba4757b6fd00d9617c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Tue, 9 Jul 2024 16:57:36 +0200 Subject: [PATCH] Address issue #47 --- metalearners/grid_search.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/metalearners/grid_search.py b/metalearners/grid_search.py index 1d85f637..cfbc480c 100644 --- a/metalearners/grid_search.py +++ b/metalearners/grid_search.py @@ -115,13 +115,16 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame: row[f"test_{name}"] = value rows.append(row) df = pd.DataFrame(rows) - index_columns = [ - c - for c in df.columns - if not c.endswith("_time") - and not c.startswith("train_") - and not c.startswith("test_") - ] + sorted_cols = sorted(df.columns) + index_columns = ["metalearner"] + for model_kind in nuisance_models: + for c in sorted_cols: + if c.startswith(model_kind): + index_columns.append(c) + for model_kind in treatment_models: + for c in sorted_cols: + if c.startswith(model_kind): + index_columns.append(c) df = df.set_index(index_columns) return df