Skip to content

Commit

Permalink
ran pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
velezbeltran committed Oct 16, 2024
1 parent b74127e commit 6a8eef0
Showing 1 changed file with 56 additions and 14 deletions.
70 changes: 56 additions & 14 deletions testbed/notebooks/paper_notebooks/speed_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,18 @@
import pandas as pd
import pickle

from matplotlib import rcParams


# make plots pretty
sns.set(style="whitegrid")
# use seaborn's "talk" context so plot elements are scaled for presentations
sns.set_context("talk")

# Set the font to Arial
rcParams["font.family"] = "sans-serif"
rcParams["font.sans-serif"] = ["Arial"] # Use Arial if available


N_ATTEMPTS = 5
FILE_NAME_TRAIN_TIMES = "train_times.pdf"
Expand Down Expand Up @@ -97,21 +107,20 @@ def make_datapoint_from_dataset(
t_samples = []
n_sampled = None


for _ in range(N_ATTEMPTS):
t_start = time.time()
model = Treeffuser()
model.fit(x, y)
t_fit = time.time() - t_start
print(f"Time taken to fit model on {dataset_name}: {t_fit}")

x_to_sample = np.concatenate([x, x, x, x, x, x ])
x_to_sample = x_to_sample[: 1000]
x_to_sample = np.concatenate([x, x, x, x, x, x])
x_to_sample = x_to_sample[:1000]
n_sampled = len(x_to_sample)

t_start = time.time()
_ = model.sample(x_to_sample, n_samples=1)
t_sample = (time.time() - t_start)
t_sample = time.time() - t_start
print(f"Time taken to sample from model on {dataset_name}: {t_sample}")

t_fits.append(t_fit)
Expand Down Expand Up @@ -171,20 +180,37 @@ def plot_train_times(datapoints: List[_Datapoint], save_pth: str, annotate=True)
labels = [f"{NAMES_TO_PLOT[dp.dataset_name]}\n{dp.dataset_shape}" for dp in datapoints]

# Create scatter plot
plt.figure(figsize=(10, 6))
plt.figure(figsize=(15, 9))
for i in range(len(datapoints)):
plt.errorbar(x[i], y[i], yerr=yerr[i], fmt="o", capsize=5, capthick=1, ecolor="red", color=colors[i], label=labels[i])

plt.errorbar(
x[i],
y[i],
yerr=yerr[i],
fmt="o",
capsize=5,
capthick=1,
ecolor="red",
color=colors[i],
label=labels[i],
)

# Annotate each point with the dataset name and shape
if annotate:
for i, label in enumerate(labels):
plt.annotate(label, (x[i], y[i]), textcoords="offset points", xytext=(5, 10))
# make the text not too small
plt.annotate(
label,
(x[i], y[i]),
textcoords="offset points",
xytext=(5, 13),
ha="center",
fontsize=18,
)
else:
plt.legend()
# otherwise show a legend in the lower-right corner, split into two columns
plt.legend(loc="lower right", fontsize=18, ncols=2)

# Add titles and labels
plt.title("Training Times vs. Dataset Size")
plt.xlabel("Number of Samples in Dataset")
plt.ylabel("Mean Training Time (seconds)")
plt.grid(True)
Expand All @@ -193,6 +219,16 @@ def plot_train_times(datapoints: List[_Datapoint], save_pth: str, annotate=True)
plt.xlim(0, 1.2 * max(x))
plt.ylim(0, 1.2 * max(y))

# Make the font of everything larger
plt.tick_params(axis="both", which="major", labelsize=18)
plt.tick_params(axis="both", which="minor", labelsize=16)
plt.xlabel("Number of Samples in Dataset", fontsize=25)
plt.ylabel("Mean Training Time (seconds)", fontsize=25)
# plt.title("Training Times vs. Dataset Size", fontsize=28)

# tight layout
plt.tight_layout()

# Save plot
plt.savefig(save_pth)
# save also png
Expand All @@ -215,7 +251,7 @@ def make_table_for_sample_times(datapoints: List[_Datapoint], save_pth: str):

def create_a_datapoint_per_dataset(
dataset_names: List[str] = None,
) -> List[_Datapoint]:
) -> List[_Datapoint]:

# remove m5 subset / not uciml datasets
datapoints = []
Expand Down Expand Up @@ -252,10 +288,13 @@ def get_pkls(pkl_name: str):
return pickle.load(f)
except FileNotFoundError:
return None


def save_pkls(pkl_name: str, data):
    """Pickle ``data`` and write it to the file at ``pkl_name``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    with open(pkl_name, mode="wb") as out_file:
        out_file.write(pickle.dumps(data))


if __name__ == "__main__":
args = parse_args()
if not os.path.exists(args.out_dir):
Expand All @@ -269,13 +308,16 @@ def save_pkls(pkl_name: str, data):
datapoints = create_a_datapoint_per_dataset(dataset_names)
save_pkls(os.path.join(args.out_dir, FILE_NAME_DATAPOINT_PER_DATASET), datapoints)

plot_train_times(datapoints, os.path.join(args.out_dir, FILE_NAME_TRAIN_TIMES), annotate=False)
plot_train_times(
datapoints, os.path.join(args.out_dir, FILE_NAME_TRAIN_TIMES), annotate=False
)
make_table_for_sample_times(datapoints, os.path.join(args.out_dir, FILE_NAME_SAMPLE_TIMES))


datapoints = get_pkls(os.path.join(args.out_dir, FILE_NAME_DATAPOINT_PER_FRACTION))
if datapoints is None or not ATTEMPT_LOAD_DATAPOINTS:
datapoints = create_a_datapoint_per_fraction_of_dataset("m5_subset", N_SUBSETS)
save_pkls(os.path.join(args.out_dir, FILE_NAME_DATAPOINT_PER_FRACTION), datapoints)

plot_train_times(datapoints, os.path.join(args.out_dir, FILE_NAME_M5_SUBSET), annotate=True)
plot_train_times(
datapoints, os.path.join(args.out_dir, FILE_NAME_M5_SUBSET), annotate=True
)

0 comments on commit 6a8eef0

Please sign in to comment.