Skip to content

Commit

Permalink
Add option to font sizes in aggregates.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hmcezar committed Feb 28, 2024
1 parent f8fc206 commit 78a18a2
Showing 1 changed file with 43 additions and 12 deletions.
55 changes: 43 additions & 12 deletions utils/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# PYTHON_ARGCOMPLETE_OK

import MDAnalysis as mda
from MDAnalysis.analysis import distances
import numpy as np
Expand All @@ -6,7 +8,7 @@
from matplotlib.ticker import MaxNLocator
import os
import sys
import argparse
import argcomplete, argparse
from tqdm import tqdm
import warnings
import dask
Expand Down Expand Up @@ -134,6 +136,8 @@ def aggregates_clustering(
plot_dendrograms,
traj_in_memory,
save_solvent,
font_size,
caption_font_size,
summary_fig_size=(12, 8),
):
u = mda.Universe(grofile, h5mdfile, in_memory=traj_in_memory)
Expand Down Expand Up @@ -236,30 +240,42 @@ def aggregates_clustering(

# plot results
_, axs = plt.subplots(2, 2, figsize=summary_fig_size)
plt.rcParams.update({"font.size": font_size})

axs[0, 0].text(-0.2, 1.05, "(a)", transform=axs[0, 0].transAxes, fontsize=caption_font_size, va='top', ha='right')
axs[0, 1].text(-0.2, 1.05, "(b)", transform=axs[0, 1].transAxes, fontsize=caption_font_size, va='top', ha='right')
axs[1, 0].text(-0.2, 1.05, "(c)", transform=axs[1, 0].transAxes, fontsize=caption_font_size, va='top', ha='right')
axs[1, 1].text(-0.2, 1.05, "(d)", transform=axs[1, 1].transAxes, fontsize=caption_font_size, va='top', ha='right')

axs[0, 0].plot(frames, n_clusters)
axs[0, 0].axhline(avg_n_aggs, linestyle="--")
axs[0, 0].set_ylabel("Num. of aggregates")
axs[0, 0].set_xlabel("Frame")
axs[0, 0].set_ylabel("Num. of aggregates", fontsize=font_size)
axs[0, 0].set_xlabel("Frame", fontsize=font_size)
# axs[0, 0].set_ylim([4, 6])
axs[0, 0].yaxis.set_major_locator(MaxNLocator(integer=True))

xticklabels = [f"{sizes[i]}" for i in range(len(sizes))]
axs[0, 1].bar(xticklabels, freq, width=0.8)
axs[0, 1].set_ylabel("Avg. num. per snapshot")
axs[0, 1].set_xlabel("Aggregate size")
axs[0, 1].tick_params("x", labelrotation=60)
axs[0, 1].set_ylabel("Avg. num. per snapshot", fontsize=font_size)
axs[0, 1].set_xlabel("Aggregate size", fontsize=font_size)
axs[0, 1].tick_params("x", labelrotation=60, labelsize=font_size)

xticklabels = [f"{k}" for k in prob_by_size.keys()]
axs[1, 0].bar(xticklabels, prob_by_size.values(), width=0.8)
axs[1, 0].set_ylabel("Prob.")
axs[1, 0].set_xlabel("Aggregate size")
axs[1, 0].tick_params("x", labelrotation=60)
axs[1, 0].set_ylabel("Prob.", fontsize=font_size)
axs[1, 0].set_xlabel("Aggregate size", fontsize=font_size)
axs[1, 0].tick_params("x", labelrotation=60, labelsize=font_size)

xticklabels = [f"{k}" for k in prob_mol_size.keys()]
axs[1, 1].bar(xticklabels, prob_mol_size.values(), width=0.8)
axs[1, 1].set_ylabel("Prob. molecule in agg.")
axs[1, 1].set_xlabel("Aggregate size")
axs[1, 1].tick_params("x", labelrotation=60)
axs[1, 1].set_ylabel("Prob. molecule in agg.", fontsize=font_size)
axs[1, 1].set_xlabel("Aggregate size", fontsize=font_size)
axs[1, 1].tick_params("x", labelrotation=60, labelsize=font_size)

axs[0, 0].tick_params(axis='both', which='major', labelsize=font_size)
axs[0, 1].tick_params(axis='both', which='major', labelsize=font_size)
axs[1, 0].tick_params(axis='both', which='major', labelsize=font_size)
axs[1, 1].tick_params(axis='both', which='major', labelsize=font_size)

plt.tight_layout()
plt.savefig("summary_clustering.pdf", bbox_inches="tight")
Expand Down Expand Up @@ -351,13 +367,26 @@ def aggregates_clustering(
default=(8, 6),
help="two integers to define the size of the summary figure (default = 8 6)",
)
parser.add_argument(
"--font-size",
type=int,
default=12,
help="font size for the summary figure (default = 12)",
)
parser.add_argument(
"--caption-font-size",
type=int,
default=13,
help="font size for the captions of summary figure (default = 13)",
)
parser.add_argument(
"--traj-in-memory",
action="store_true",
default=False,
help="load the whole trajectory in memory with MDAanalysis",
)

argcomplete.autocomplete(parser)
args = parser.parse_args()

if os.path.splitext(args.h5md_file)[1].lower() == ".h5":
Expand Down Expand Up @@ -404,5 +433,7 @@ def aggregates_clustering(
args.plot_dendrograms,
args.traj_in_memory,
args.save_solvent,
args.font_size,
args.caption_font_size,
tuple(args.summary_fig_size),
)

0 comments on commit 78a18a2

Please sign in to comment.