Skip to content

Commit

Permalink
fixed PR curve and added images folder
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrlaltaf committed Sep 10, 2024
1 parent bb4822b commit eff16df
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 112 deletions.
253 changes: 141 additions & 112 deletions generate_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,119 +55,132 @@ def create_plot(ax, x_data: list, y_data: list, auc: float, type: str, color) ->
def main():
print("Generating figures")
species_list = ["elegans", "fly", "bsub", "yeast", "zfish"]
file_directory = "./results/final-non-inferred-complete/"
final_category_data = defaultdict(list)

for species in species_list:
overlapping_path = Path(
file_directory, f"{species}/overlapping_neighbor_data.csv"
)
hypergeometric_path = Path(
file_directory, f"{species}/hypergeometric_distribution.csv"
)
degree_path = Path(file_directory, f"{species}/protein_degree_v3_data.csv")
rw_path = Path(file_directory, f"{species}/random_walk_data_v2.csv")

species_path = [overlapping_path, hypergeometric_path, degree_path, rw_path]

methods = []
for path in species_path:
data = read_file(path)
methods.append(data)

# calculate AUC values
fpr_list = []
tpr_list = []
threshold_list = []
roc_auc_list = []
precision_list = []
recall_list = []
pr_auc_list = []

for data in methods:
fpr, tpr, threshold, roc_auc = get_roc_data(data)
fpr_list.append(fpr)
tpr_list.append(tpr)
threshold_list.append(threshold)
roc_auc_list.append(roc_auc)

precision, recall, pr_auc = get_pr_data(data)
precision_list.append(precision)
recall_list.append(recall)
pr_auc_list.append(pr_auc)

species_data = {
"fpr": fpr_list,
"tpr": tpr_list,
"roc": roc_auc_list,
"precision": precision_list,
"recall": recall_list,
"pr": pr_auc_list,
"method": ["Overlapping", "Hypergeometric", "Degree", "RW"],
}
final_category_data[species].append(species_data)

# Create a figure with 2 subplots (one for each species)
fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # Create a 2x3 grid of subplots
axes = axes.flatten()
colors = ["red", "green", "blue", "orange", "purple"]
file_directories = [
"./results/final-non-inferred-complete/",
"./results/final-inferred-complete/",
]
subplot_titles = ["Complete Non Inferred Networks", "Complete Inferred Networks"]
k = 0
for directory in file_directories:
final_category_data = defaultdict(list)

for idx, species in enumerate(species_list):
ax = axes[idx] # Get the subplot axis for the current species

for i in range(len(final_category_data[species][0]["method"])):
create_plot(
ax,
final_category_data[species][0]["fpr"][i],
final_category_data[species][0]["tpr"][i],
final_category_data[species][0]["roc"][i],
final_category_data[species][0]["method"][i],
colors[i],
for species in species_list:
overlapping_path = Path(
directory, f"{species}/overlapping_neighbor_data.csv"
)
hypergeometric_path = Path(
directory, f"{species}/hypergeometric_distribution.csv"
)
degree_path = Path(directory, f"{species}/protein_degree_v3_data.csv")
rw_path = Path(directory, f"{species}/random_walk_data_v2.csv")

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title(f"{species.capitalize()}")
ax.legend(loc="lower right")

axes[5].set_visible(False)
fig.suptitle("ROC Curve for All Species w/ Complete Inferred Networks", fontsize=20)
plt.tight_layout()
plt.show()
species_path = [overlapping_path, hypergeometric_path, degree_path, rw_path]

fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # Create a 2x3 grid of subplots
axes = axes.flatten()
colors = ["red", "green", "blue", "orange", "purple"]
methods = []
for path in species_path:
data = read_file(path)
methods.append(data)

for idx, species in enumerate(species_list):
ax = axes[idx] # Get the subplot axis for the current species

for i in range(len(final_category_data[species][0]["method"])):
create_plot(
ax,
final_category_data[species][0]["precision"][i],
final_category_data[species][0]["recall"][i],
final_category_data[species][0]["pr"][i],
final_category_data[species][0]["method"][i],
colors[i],
)

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("Precision")
ax.set_ylabel("Recall")
ax.set_title(f"{species.capitalize()}")
ax.legend(loc="lower right")
# calculate AUC values
fpr_list = []
tpr_list = []
threshold_list = []
roc_auc_list = []
precision_list = []
recall_list = []
pr_auc_list = []

for data in methods:
fpr, tpr, threshold, roc_auc = get_roc_data(data)
fpr_list.append(fpr)
tpr_list.append(tpr)
threshold_list.append(threshold)
roc_auc_list.append(roc_auc)

precision, recall, pr_auc = get_pr_data(data)
precision_list.append(precision)
recall_list.append(recall)
pr_auc_list.append(pr_auc)

axes[5].set_visible(False)
fig.suptitle(
"Precision/Recall Curve for All Species w/ Complete Inferred Networks",
fontsize=20,
)
plt.tight_layout()
plt.show()
species_data = {
"fpr": fpr_list,
"tpr": tpr_list,
"roc": roc_auc_list,
"precision": precision_list,
"recall": recall_list,
"pr": pr_auc_list,
"method": ["Overlapping", "Hypergeometric", "Degree", "RW"],
}
final_category_data[species].append(species_data)

# Create a figure with 2 subplots (one for each species)
fig, axes = plt.subplots(
2, 3, figsize=(18, 10)
) # Create a 2x3 grid of subplots
axes = axes.flatten()
colors = ["red", "green", "blue", "orange", "purple"]

for idx, species in enumerate(species_list):
ax = axes[idx] # Get the subplot axis for the current species

for i in range(len(final_category_data[species][0]["method"])):
create_plot(
ax,
final_category_data[species][0]["fpr"][i],
final_category_data[species][0]["tpr"][i],
final_category_data[species][0]["roc"][i],
final_category_data[species][0]["method"][i],
colors[i],
)

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title(f"{species.capitalize()}")
ax.legend(loc="lower right")

axes[5].set_visible(False)
fig.suptitle("ROC Curve for All Species w/ " + subplot_titles[k], fontsize=20)
plt.savefig(Path("./results/images/", f"roc_{subplot_titles[k].lower().replace(" ", "_")}"))
plt.tight_layout()
plt.show()

fig, axes = plt.subplots(
2, 3, figsize=(18, 10)
) # Create a 2x3 grid of subplots
axes = axes.flatten()
colors = ["red", "green", "blue", "orange", "purple"]

for idx, species in enumerate(species_list):
ax = axes[idx] # Get the subplot axis for the current species

for i in range(len(final_category_data[species][0]["method"])):
create_plot(
ax,
final_category_data[species][0]["recall"][i],
final_category_data[species][0]["precision"][i],
final_category_data[species][0]["pr"][i],
final_category_data[species][0]["method"][i],
colors[i],
)

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title(f"{species.capitalize()}")
ax.legend(loc="lower right")

axes[5].set_visible(False)
fig.suptitle(
"Precision/Recall Curve for All Species w/ " + subplot_titles[k],
fontsize=20,
)
plt.tight_layout()
plt.savefig(Path("./results/images/", f"pr_{subplot_titles[k].lower().replace(" ", "_")}"))
plt.show()
k += 1

# generate RW figures

Expand All @@ -178,6 +191,12 @@ def main():
"./results/final-rw-non-inferred-regular/",
"./results/final-rw-non-inferred-pro-go/",
]
subplot_titles = [
"Inferred Complete Network",
"Inferred ProGo Network",
"Non Inferred Complete Network",
"Non Inferred ProGo Network",
]
final_rw_data = defaultdict(list)

# Load data for each directory and species
Expand Down Expand Up @@ -222,10 +241,15 @@ def main():
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title(f"{directory}")
ax.set_title(f"{subplot_titles[idx]}")
ax.legend(loc="lower right")

# Adjust layout and show the plot
fig.suptitle(
"ROC Curve for RandomWalk Configuration",
fontsize=20,
)
plt.savefig(Path("./results/images/rw_roc.png"))
plt.tight_layout()
plt.show()

Expand All @@ -239,21 +263,26 @@ def main():
ax = axs[idx] # Get the corresponding subplot
for i, species in enumerate(species_list):
ax.plot(
final_rw_data[species][idx]["precision"],
final_rw_data[species][idx]["recall"],
final_rw_data[species][idx]["precision"],
color=colors[i],
lw=2,
label=f"{species} (area = %0.2f)" % final_rw_data[species][idx]["roc"],
)

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("Precision")
ax.set_ylabel("Recall")
ax.set_title(f"{directory}")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title(f"{subplot_titles[idx]}")
ax.legend(loc="lower right")

# Adjust layout and show the plot
fig.suptitle(
"Precision/Recall Curve for RandomWalk Configuration",
fontsize=20,
)
plt.savefig(Path("./results/images/rw_pr.png"))
plt.tight_layout()
plt.show()

Expand Down
Binary file added results/images/pr_complete_inferred_networks.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/images/rw_pr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/images/rw_roc.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit eff16df

Please sign in to comment.