Skip to content

Commit

Permalink
Address issue #47
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 9, 2024
1 parent f8c0db7 commit 38e84fd
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions metalearners/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,16 @@ def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame:
row[f"test_{name}"] = value
rows.append(row)
df = pd.DataFrame(rows)
index_columns = [
c
for c in df.columns
if not c.endswith("_time")
and not c.startswith("train_")
and not c.startswith("test_")
]
sorted_cols = sorted(df.columns)
index_columns = ["metalearner"]
for model_kind in nuisance_models:
for c in sorted_cols:
if c.startswith(model_kind):
index_columns.append(c)
for model_kind in treatment_models:
for c in sorted_cols:
if c.startswith(model_kind):
index_columns.append(c)
df = df.set_index(index_columns)
return df

Expand Down

0 comments on commit 38e84fd

Please sign in to comment.