Skip to content

Commit

Permalink
Merge pull request #228 from hmcezar/main
Browse files Browse the repository at this point in the history
Changes in utils scripts
  • Loading branch information
hmcezar authored Sep 6, 2024
2 parents 58018f4 + feea5af commit 0670bab
Show file tree
Hide file tree
Showing 6 changed files with 1,152 additions and 90 deletions.
6 changes: 3 additions & 3 deletions pytest-mpi
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ else
fi

if [ "${ORDEROUTPUT}" -eq 1 ]; then
mpirun -n ${NPROCS} --output-filename outtest pytest ${PYTEST_ARGS} --only-mpi >/dev/null
mpirun -n ${NPROCS} --oversubscribe --output-filename outtest pytest ${PYTEST_ARGS} --only-mpi >/dev/null
exit_code=$?
for ((RANK=0; RANK<${NPROCS}; RANK++)); do
RANKFILE="outtest/1/rank.${RANK}/stdout"
Expand All @@ -94,10 +94,10 @@ else
done
rm -r outtest
elif [ "${UNMUTE}" -eq -1 ]; then
mpirun -n ${NPROCS} pytest ${PYTEST_ARGS} --only-mpi
mpirun -n ${NPROCS} --oversubscribe pytest ${PYTEST_ARGS} --only-mpi
exit_code=$?
else
mpirun -n ${NPROCS} utils/mute_all_ranks_except.sh ${UNMUTE} pytest ${PYTEST_ARGS} --only-mpi
mpirun -n ${NPROCS} --oversubscribe utils/mute_all_ranks_except.sh ${UNMUTE} pytest ${PYTEST_ARGS} --only-mpi
exit_code=$?
fi
fi
Expand Down
71 changes: 53 additions & 18 deletions utils/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# PYTHON_ARGCOMPLETE_OK

import MDAnalysis as mda
from MDAnalysis.analysis import distances
import numpy as np
Expand All @@ -6,7 +8,7 @@
from matplotlib.ticker import MaxNLocator
import os
import sys
import argparse
import argcomplete, argparse
from tqdm import tqdm
import warnings
import dask
Expand Down Expand Up @@ -134,6 +136,8 @@ def aggregates_clustering(
plot_dendrograms,
traj_in_memory,
save_solvent,
font_size,
caption_font_size,
summary_fig_size=(12, 8),
):
u = mda.Universe(grofile, h5mdfile, in_memory=traj_in_memory)
Expand Down Expand Up @@ -182,21 +186,19 @@ def aggregates_clustering(

n_clusters = []
all_sizes = []
total_clust_by_size = {}
clusters_per_frame = []
for c in clusters[0]:
# get the number of clusters and sizes
unique_clusts, clust_counts = np.unique(c, return_counts=True)
for size in clust_counts:
if size not in total_clust_by_size:
total_clust_by_size[size] = 1
else:
total_clust_by_size[size] += 1

n_clusters.append(len(unique_clusts))

clusters_per_frame.append(np.sort(clust_counts))
all_sizes += clust_counts.tolist()

# based on cluster sizes get occurence of each size
sizes, freq = np.unique(all_sizes, return_counts=True)
total_clust_by_size = dict(zip(sizes, freq))
freq = freq / len(u.trajectory[skip:end:stride])

# overall average number of aggregates
Expand Down Expand Up @@ -234,32 +236,50 @@ def aggregates_clustering(
for s, p in prob_mol_size.items():
of.write(f"{s}\t{p}\n")

of.write("\nClusters sizes in each analyzed frame:\n")
for c in clusters_per_frame:
for s in c:
of.write(f"{s}\t")
of.write("\n")

# plot results
_, axs = plt.subplots(2, 2, figsize=summary_fig_size)
plt.rcParams.update({"font.size": font_size})

axs[0, 0].text(-0.2, 1.05, "(a)", transform=axs[0, 0].transAxes, fontsize=caption_font_size, va='top', ha='right')
axs[0, 1].text(-0.2, 1.05, "(b)", transform=axs[0, 1].transAxes, fontsize=caption_font_size, va='top', ha='right')
axs[1, 0].text(-0.2, 1.05, "(c)", transform=axs[1, 0].transAxes, fontsize=caption_font_size, va='top', ha='right')
axs[1, 1].text(-0.2, 1.05, "(d)", transform=axs[1, 1].transAxes, fontsize=caption_font_size, va='top', ha='right')

axs[0, 0].plot(frames, n_clusters)
axs[0, 0].axhline(avg_n_aggs, linestyle="--")
axs[0, 0].set_ylabel("Num. of aggregates")
axs[0, 0].set_xlabel("Frame")
axs[0, 0].set_ylabel("Num. of aggregates", fontsize=font_size)
axs[0, 0].set_xlabel("Frame", fontsize=font_size)
# axs[0, 0].set_ylim([4, 6])
axs[0, 0].yaxis.set_major_locator(MaxNLocator(integer=True))

xticklabels = [f"{sizes[i]}" for i in range(len(sizes))]
axs[0, 1].bar(xticklabels, freq, width=0.8)
axs[0, 1].set_ylabel("Avg. num. per snapshot")
axs[0, 1].set_xlabel("Aggregate size")
axs[0, 1].tick_params("x", labelrotation=60)
axs[0, 1].set_ylabel("Avg. num. per snapshot", fontsize=font_size)
axs[0, 1].set_xlabel("Aggregate size", fontsize=font_size)
axs[0, 1].tick_params("x", labelrotation=60, labelsize=font_size)

xticklabels = [f"{k}" for k in prob_by_size.keys()]
axs[1, 0].bar(xticklabels, prob_by_size.values(), width=0.8)
axs[1, 0].set_ylabel("Prob.")
axs[1, 0].set_xlabel("Aggregate size")
axs[1, 0].tick_params("x", labelrotation=60)
axs[1, 0].set_ylabel("Prob.", fontsize=font_size)
axs[1, 0].set_xlabel("Aggregate size", fontsize=font_size)
axs[1, 0].tick_params("x", labelrotation=60, labelsize=font_size)

xticklabels = [f"{k}" for k in prob_mol_size.keys()]
axs[1, 1].bar(xticklabels, prob_mol_size.values(), width=0.8)
axs[1, 1].set_ylabel("Prob. molecule in agg.")
axs[1, 1].set_xlabel("Aggregate size")
axs[1, 1].tick_params("x", labelrotation=60)
axs[1, 1].set_ylabel("Prob. molecule in agg.", fontsize=font_size)
axs[1, 1].set_xlabel("Aggregate size", fontsize=font_size)
axs[1, 1].tick_params("x", labelrotation=60, labelsize=font_size)

axs[0, 0].tick_params(axis='both', which='major', labelsize=font_size)
axs[0, 1].tick_params(axis='both', which='major', labelsize=font_size)
axs[1, 0].tick_params(axis='both', which='major', labelsize=font_size)
axs[1, 1].tick_params(axis='both', which='major', labelsize=font_size)

plt.tight_layout()
plt.savefig("summary_clustering.pdf", bbox_inches="tight")
Expand Down Expand Up @@ -351,13 +371,26 @@ def aggregates_clustering(
default=(8, 6),
help="two integers to define the size of the summary figure (default = 8 6)",
)
parser.add_argument(
"--font-size",
type=int,
default=12,
help="font size for the summary figure (default = 12)",
)
parser.add_argument(
"--caption-font-size",
type=int,
default=13,
help="font size for the captions of summary figure (default = 13)",
)
parser.add_argument(
"--traj-in-memory",
action="store_true",
default=False,
help="load the whole trajectory in memory with MDAanalysis",
)

argcomplete.autocomplete(parser)
args = parser.parse_args()

if os.path.splitext(args.h5md_file)[1].lower() == ".h5":
Expand Down Expand Up @@ -404,5 +437,7 @@ def aggregates_clustering(
args.plot_dendrograms,
args.traj_in_memory,
args.save_solvent,
args.font_size,
args.caption_font_size,
tuple(args.summary_fig_size),
)
Loading

0 comments on commit 0670bab

Please sign in to comment.