Skip to content

Commit

Permalink
Improved output of aggregates
Browse files Browse the repository at this point in the history
  • Loading branch information
hmcezar committed Sep 26, 2023
1 parent 33ab182 commit 6fe8d4c
Showing 1 changed file with 70 additions and 12 deletions.
82 changes: 70 additions & 12 deletions utils/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import os
import sys
import argparse
from tqdm import tqdm
import warnings
Expand Down Expand Up @@ -114,6 +115,7 @@ def compute_clusters(

print_sel.write(f"./colored_pdbs/snap_{frame}.pdb")

plt.close()
return clusters


Expand All @@ -132,6 +134,7 @@ def aggregates_clustering(
plot_dendrograms,
traj_in_memory,
save_solvent,
summary_fig_size=(12, 8),
):
u = mda.Universe(grofile, h5mdfile, in_memory=traj_in_memory)

Expand Down Expand Up @@ -178,38 +181,85 @@ def aggregates_clustering(
clusters = dask.compute(job_list, num_workers=nworkers)

n_clusters = []
clust_sizes = []
all_sizes = []
total_clust_by_size = {}
for c in clusters[0]:
# get the number of clusters and sizes
unique_clusts, clust_counts = np.unique(c, return_counts=True)
for size in clust_counts:
if size not in total_clust_by_size:
total_clust_by_size[size] = 1
else:
total_clust_by_size[size] += 1
n_clusters.append(len(unique_clusts))

clust_sizes.append(clust_counts)
all_sizes += clust_counts.tolist()

# based on cluster sizes get occurence of each size
sizes, freq = np.unique(all_sizes, return_counts=True)
freq = freq / len(u.trajectory[skip:end:stride])

# write sizes and freq to file
# overall average number of aggregates
avg_n_aggs = np.average(n_clusters)

# compute probability of picking cluster of size n
prob_by_size = {}
for k, v in total_clust_by_size.items():
prob_by_size[k] = v / np.sum(n_clusters)
prob_by_size = dict(sorted(prob_by_size.items()))

# compute probability of picking a random molecule and
# it belonging to a cluster of size n
prob_mol_size = {}
norm = len(at_sel) * len(u.trajectory[skip:end:stride])
for k, v in total_clust_by_size.items():
prob_mol_size[k] = k * v / norm
prob_mol_size = dict(sorted(prob_mol_size.items()))

# write summary to file
with open("summary_clustering.dat", "w") as of:
of.write("Executed command: " + " ".join(sys.argv) + "\n")

of.write(f"\nAverage number of aggregates: {avg_n_aggs}\n")

of.write("\nsize\tfrequency\n")
for s, f in zip(sizes, freq):
of.write(f"{s}\t{f}\n")

of.write("\nsize\tprobability\n")
for s, p in prob_by_size.items():
of.write(f"{s}\t{p}\n")

of.write("\nsize\tprob molecule\n")
for s, p in prob_mol_size.items():
of.write(f"{s}\t{p}\n")

# plot results
fig, (ax1, ax2) = plt.subplots(2, 1)
_, axs = plt.subplots(2, 2, figsize=summary_fig_size)

ax1.plot(frames, n_clusters)
ax1.set_ylabel("Number of aggregates")
ax1.set_xlabel("Frame")
ax1.yaxis.set_major_locator(MaxNLocator(integer=True))
axs[0, 0].plot(frames, n_clusters)
axs[0, 0].axhline(avg_n_aggs, linestyle="--")
axs[0, 0].set_ylabel("Num. of aggregates")
axs[0, 0].set_xlabel("Frame")
axs[0, 0].yaxis.set_major_locator(MaxNLocator(integer=True))

xticklabels = [f"{sizes[i]}" for i in range(len(sizes))]
ax2.bar(xticklabels, freq, width=0.8)
ax2.set_ylabel("Frequency")
ax2.set_xlabel("Aggregate size")
ax2.tick_params("x", labelrotation=60)
axs[0, 1].bar(xticklabels, freq, width=0.8)
axs[0, 1].set_ylabel("Avg. num. per snapshot")
axs[0, 1].set_xlabel("Aggregate size")
axs[0, 1].tick_params("x", labelrotation=60)

xticklabels = [f"{k}" for k in prob_by_size.keys()]
axs[1, 0].bar(xticklabels, prob_by_size.values(), width=0.8)
axs[1, 0].set_ylabel("Prob.")
axs[1, 0].set_xlabel("Aggregate size")
axs[1, 0].tick_params("x", labelrotation=60)

xticklabels = [f"{k}" for k in prob_mol_size.keys()]
axs[1, 1].bar(xticklabels, prob_mol_size.values(), width=0.8)
axs[1, 1].set_ylabel("Prob. molecule in agg.")
axs[1, 1].set_xlabel("Aggregate size")
axs[1, 1].tick_params("x", labelrotation=60)

plt.tight_layout()
plt.savefig("summary_clustering.pdf", bbox_inches="tight")
Expand Down Expand Up @@ -294,6 +344,13 @@ def aggregates_clustering(
default=False,
help="plot the dendrograms (saved in ./dendrograms) (use with stride because its ~10x slower)",
)
parser.add_argument(
"--summary-fig-size",
type=int,
nargs=2,
default=(8, 6),
help="two integers to define the size of the summary figure (default = 8 6)",
)
parser.add_argument(
"--traj-in-memory",
action="store_true",
Expand Down Expand Up @@ -347,4 +404,5 @@ def aggregates_clustering(
args.plot_dendrograms,
args.traj_in_memory,
args.save_solvent,
tuple(args.summary_fig_size),
)

0 comments on commit 6fe8d4c

Please sign in to comment.