Skip to content

Commit

Permalink
linter issues
Browse files Browse the repository at this point in the history
  • Loading branch information
1andrin committed Oct 24, 2024
1 parent e8470d9 commit 7f90d95
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
4 changes: 2 additions & 2 deletions causaltune/erupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def probabilistic_erupt_score(
print(
f"\nDebugging Probabilistic ERUPT for estimator: {est.__class__.__name__}"
)
print(f"CATE estimate summary:")
print("CATE estimate summary:")
print(f"Mean: {np.mean(cate_estimate):.4f}")
print(f"Std: {np.std(cate_estimate):.4f}")
print(f"Min: {np.min(cate_estimate):.4f}")
Expand Down Expand Up @@ -257,7 +257,7 @@ def probabilistic_erupt_score(
)
effect_stds = np.broadcast_to(effect_stds, cate_estimate.shape)

print(f"\nStandard errors summary:")
print("\nStandard errors summary:")
print(f"Mean: {np.mean(effect_stds):.4f}")
print(f"Std: {np.std(effect_stds):.4f}")
print(f"Min: {np.min(effect_stds):.4f}")
Expand Down
27 changes: 11 additions & 16 deletions notebooks/RunExperiments/experiment_runner.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,24 @@
import os
import sys

# Ensure CausalTune is in the Python path
root_path = os.path.realpath("../../../..")
sys.path.append(os.path.join(root_path, "causaltune"))

import pickle
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import copy
import argparse
from typing import Union
from sklearn.model_selection import train_test_split
from datetime import datetime
import colorsys
import warnings

warnings.filterwarnings("ignore")

# Ensure CausalTune is in the Python path
root_path = os.path.realpath("../../../..")
sys.path.append(os.path.join(root_path, "causaltune"))

from causaltune import CausalTune
from causaltune.data_utils import CausalityDataset
from causaltune.datasets import (
generate_synthetic_data,
generate_linear_synthetic_data,
load_dataset,
)
from causaltune.models.passthrough import passthrough_model
Expand Down Expand Up @@ -296,10 +291,10 @@ def plot_grid(title):
plt.suptitle(f"Estimated CATEs vs. True CATEs: {title}", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(
os.path.join(out_dir, f"CATE_grid.pdf"), format="pdf", bbox_inches="tight"
os.path.join(out_dir, "CATE_grid.pdf"), format="pdf", bbox_inches="tight"
)
plt.savefig(
os.path.join(out_dir, f"CATE_grid.png"), format="png", bbox_inches="tight"
os.path.join(out_dir, "CATE_grid.png"), format="png", bbox_inches="tight"
)
plt.close()

Expand Down Expand Up @@ -376,10 +371,10 @@ def plot_mse_grid(title):
plt.suptitle(f"MSE vs. Scores: {title}", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(
os.path.join(out_dir, f"MSE_grid.pdf"), format="pdf", bbox_inches="tight"
os.path.join(out_dir, "MSE_grid.pdf"), format="pdf", bbox_inches="tight"
)
plt.savefig(
os.path.join(out_dir, f"MSE_grid.png"), format="png", bbox_inches="tight"
os.path.join(out_dir, "MSE_grid.png"), format="png", bbox_inches="tight"
)
plt.close()

Expand All @@ -388,10 +383,10 @@ def plot_mse_grid(title):
ax_legend.legend(handles=legend_elements, loc="center", fontsize=10)
ax_legend.axis("off")
plt.savefig(
os.path.join(out_dir, f"MSE_legend.pdf"), format="pdf", bbox_inches="tight"
os.path.join(out_dir, "MSE_legend.pdf"), format="pdf", bbox_inches="tight"
)
plt.savefig(
os.path.join(out_dir, f"MSE_legend.png"), format="png", bbox_inches="tight"
os.path.join(out_dir, "MSE_legend.png"), format="png", bbox_inches="tight"
)
plt.close()

Expand Down

0 comments on commit 7f90d95

Please sign in to comment.