diff --git a/causaltune/erupt.py b/causaltune/erupt.py index a1c57c50..2dff7c17 100644 --- a/causaltune/erupt.py +++ b/causaltune/erupt.py @@ -202,7 +202,7 @@ def probabilistic_erupt_score( print( f"\nDebugging Probabilistic ERUPT for estimator: {est.__class__.__name__}" ) - print(f"CATE estimate summary:") + print("CATE estimate summary:") print(f"Mean: {np.mean(cate_estimate):.4f}") print(f"Std: {np.std(cate_estimate):.4f}") print(f"Min: {np.min(cate_estimate):.4f}") @@ -257,7 +257,7 @@ def probabilistic_erupt_score( ) effect_stds = np.broadcast_to(effect_stds, cate_estimate.shape) - print(f"\nStandard errors summary:") + print("\nStandard errors summary:") print(f"Mean: {np.mean(effect_stds):.4f}") print(f"Std: {np.std(effect_stds):.4f}") print(f"Min: {np.min(effect_stds):.4f}") diff --git a/notebooks/RunExperiments/experiment_runner.py b/notebooks/RunExperiments/experiment_runner.py index a81eeac9..49a33c3f 100644 --- a/notebooks/RunExperiments/experiment_runner.py +++ b/notebooks/RunExperiments/experiment_runner.py @@ -1,29 +1,24 @@ import os import sys + +# Ensure CausalTune is in the Python path +root_path = os.path.realpath("../../../..") +sys.path.append(os.path.join(root_path, "causaltune")) + import pickle import numpy as np -import pandas as pd import matplotlib import matplotlib.pyplot as plt import copy import argparse -from typing import Union from sklearn.model_selection import train_test_split from datetime import datetime -import colorsys import warnings warnings.filterwarnings("ignore") -# Ensure CausalTune is in the Python path -root_path = os.path.realpath("../../../..") -sys.path.append(os.path.join(root_path, "causaltune")) - from causaltune import CausalTune -from causaltune.data_utils import CausalityDataset from causaltune.datasets import ( - generate_synthetic_data, - generate_linear_synthetic_data, load_dataset, ) from causaltune.models.passthrough import passthrough_model @@ -296,10 +291,10 @@ def plot_grid(title): plt.suptitle(f"Estimated CATEs vs. True CATEs: {title}", fontsize=16) plt.tight_layout(rect=[0, 0, 1, 0.96]) plt.savefig( - os.path.join(out_dir, f"CATE_grid.pdf"), format="pdf", bbox_inches="tight" + os.path.join(out_dir, "CATE_grid.pdf"), format="pdf", bbox_inches="tight" ) plt.savefig( - os.path.join(out_dir, f"CATE_grid.png"), format="png", bbox_inches="tight" + os.path.join(out_dir, "CATE_grid.png"), format="png", bbox_inches="tight" ) plt.close() @@ -376,10 +371,10 @@ def plot_mse_grid(title): plt.suptitle(f"MSE vs. Scores: {title}", fontsize=16) plt.tight_layout(rect=[0, 0, 1, 0.96]) plt.savefig( - os.path.join(out_dir, f"MSE_grid.pdf"), format="pdf", bbox_inches="tight" + os.path.join(out_dir, "MSE_grid.pdf"), format="pdf", bbox_inches="tight" ) plt.savefig( - os.path.join(out_dir, f"MSE_grid.png"), format="png", bbox_inches="tight" + os.path.join(out_dir, "MSE_grid.png"), format="png", bbox_inches="tight" ) plt.close() @@ -388,10 +383,10 @@ def plot_mse_grid(title): ax_legend.legend(handles=legend_elements, loc="center", fontsize=10) ax_legend.axis("off") plt.savefig( - os.path.join(out_dir, f"MSE_legend.pdf"), format="pdf", bbox_inches="tight" + os.path.join(out_dir, "MSE_legend.pdf"), format="pdf", bbox_inches="tight" ) plt.savefig( - os.path.join(out_dir, f"MSE_legend.png"), format="png", bbox_inches="tight" + os.path.join(out_dir, "MSE_legend.png"), format="png", bbox_inches="tight" ) plt.close()