Skip to content

Commit

Permalink
added PR curve figure
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrlaltaf committed Sep 10, 2024
1 parent f6ddffc commit 94e06c8
Show file tree
Hide file tree
Showing 21 changed files with 1,214,881 additions and 11 deletions.
72 changes: 61 additions & 11 deletions generate_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from matplotlib import pyplot as plt
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.metrics import auc, precision_recall_curve, roc_curve


def read_file(filepath: str) -> dict:
Expand Down Expand Up @@ -33,6 +33,15 @@ def get_roc_data(data_df: dict) -> list:
return fpr, tpr, threshold, roc_auc


def get_pr_data(data_df: dict) -> list:
y = np.array(data_df["y_true"])
scores = np.array(data_df["y_score"])
precision, recall , _= precision_recall_curve(y, scores)
pr_auc = auc(recall, precision)

return precision, recall, pr_auc


def create_plot(ax, x_data: list, y_data: list, auc: float, type: str, color) -> None:
ax.plot(
x_data,
Expand All @@ -46,21 +55,18 @@ def create_plot(ax, x_data: list, y_data: list, auc: float, type: str, color) ->
def main():
print("Generating figures")
species_list = ["elegans", "fly", "bsub", "yeast", "zfish"]
file_directory = "./results/final-non-inferred-complete/"
final_data = defaultdict(list)

for species in species_list:
overlapping_path = Path(
f"./results/final-inferred-complete/{species}/overlapping_neighbor_data.csv"
file_directory, f"{species}/overlapping_neighbor_data.csv"
)
hypergeometric_path = Path(
f"./results/final-inferred-complete/{species}/hypergeometric_distribution.csv"
)
degree_path = Path(
f"./results/final-inferred-complete/{species}/protein_degree_v3_data.csv"
)
rw_path = Path(
f"./results/final-inferred-complete/{species}/random_walk_data_v2.csv"
file_directory, f"{species}/hypergeometric_distribution.csv"
)
degree_path = Path(file_directory, f"{species}/protein_degree_v3_data.csv")
rw_path = Path(file_directory, f"{species}/random_walk_data_v2.csv")

species_path = [overlapping_path, hypergeometric_path, degree_path, rw_path]

Expand All @@ -74,6 +80,9 @@ def main():
tpr_list = []
threshold_list = []
roc_auc_list = []
precision_list = []
recall_list = []
pr_auc_list = []

for data in methods:
fpr, tpr, threshold, roc_auc = get_roc_data(data)
Expand All @@ -82,10 +91,18 @@ def main():
threshold_list.append(threshold)
roc_auc_list.append(roc_auc)

precision, recall, pr_auc = get_pr_data(data)
precision_list.append(precision)
recall_list.append(recall)
pr_auc_list.append(pr_auc)

species_data = {
"fpr": fpr_list,
"tpr": tpr_list,
"roc": roc_auc_list,
"precision": precision_list,
"recall": recall_list,
"pr": pr_auc_list,
"method": ["Overlapping", "Hypergeometric", "Degree", "RW"],
}
final_data[species].append(species_data)
Expand All @@ -105,17 +122,50 @@ def main():
final_data[species][0]["tpr"][i],
final_data[species][0]["roc"][i],
final_data[species][0]["method"][i],
colors[i]
colors[i],
)

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title(f"ROC Curve for {species.capitalize()}")
ax.set_title(f"{species.capitalize()}")
ax.legend(loc="lower right")

axes[5].set_visible(False)
fig.suptitle("ROC Curve for All Species w/ Complete Inferred Networks", fontsize=20)
plt.tight_layout()
plt.show()

fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # Create a 2x3 grid of subplots
axes = axes.flatten()
colors = ["red", "green", "blue", "orange", "purple"]

for idx, species in enumerate(species_list):
ax = axes[idx] # Get the subplot axis for the current species

for i in range(len(final_data[species][0]["method"])):
create_plot(
ax,
final_data[species][0]["precision"][i],
final_data[species][0]["recall"][i],
final_data[species][0]["pr"][i],
final_data[species][0]["method"][i],
colors[i],
)

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("Precision")
ax.set_ylabel("Recall")
ax.set_title(f"{species.capitalize()}")
ax.legend(loc="lower right")

axes[5].set_visible(False)
fig.suptitle(
"Precision/Recall Curve for All Species w/ Complete Inferred Networks",
fontsize=20,
)
plt.tight_layout()
plt.show()

Expand Down
Loading

0 comments on commit 94e06c8

Please sign in to comment.