Skip to content

Commit

Permalink
Add option to save solvent beads
Browse files Browse the repository at this point in the history
  • Loading branch information
hmcezar committed Jul 24, 2023
1 parent 8697dd0 commit 66bb6a2
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions utils/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def explore_methods(


def compute_clusters(
not_w, at_sel, ts, cutoff, linkage_method, save_snaps, plot_dendrograms
print_sel, at_sel, ts, cutoff, linkage_method, save_snaps, plot_dendrograms
):
# get CM for each mol (assumes each ResID is a mol)
mol_cms = []
Expand All @@ -84,7 +84,7 @@ def compute_clusters(
np.array(mol_cms), box=ts.dimensions, result=cond_distmat, backend="OpenMP"
)

# get dendrogram and plot
# get dendrogram
Z = hcl.linkage(cond_distmat, linkage_method)

if plot_dendrograms:
Expand All @@ -99,12 +99,20 @@ def compute_clusters(
plt.axhline(cutoff, linestyle="--")
plt.savefig(f"./dendrograms/snap_{ts.frame}.pdf", bbox_inches="tight")

# build the clusters and print them to file
# build the clusters
clusters = hcl.fcluster(Z, cutoff, criterion="distance")

if save_snaps:
not_w.residues.resids = clusters
not_w.write(f"./colored_pdbs/snap_{ts.frame}.pdb")
if len(print_sel.residues.resids) > len(clusters):
resids = np.full((len(print_sel.residues.resids)), np.max(clusters) + 1)
resids[:len(clusters)] = clusters
elif len(print_sel.residues.resids) == len(clusters):
resids = clusters
else:
raise AssertionError("Something is wrong with your selection")
print_sel.residues.resids = resids

print_sel.write(f"./colored_pdbs/snap_{ts.frame}.pdb")

return clusters

Expand All @@ -122,12 +130,18 @@ def aggregates_clustering(
save_snaps,
plot_dendrograms,
traj_in_memory,
save_solvent,
):
u = mda.Universe(grofile, h5mdfile, in_memory=traj_in_memory)

# select atoms that are not the solvent
not_w = u.select_atoms("not name " + solvent_name)

if save_solvent:
print_sel = u.select_atoms("all")
else:
print_sel = not_w

# get CM for each mol (assumes each ResID is a mol)
at_sel = []
for mol in not_w.split("residue"):
Expand All @@ -148,7 +162,7 @@ def aggregates_clustering(
job_list.append(
dask.delayed(
compute_clusters(
not_w,
print_sel,
at_sel,
ts,
cutoff,
Expand Down Expand Up @@ -188,8 +202,8 @@ def aggregates_clustering(
ax2.bar(xticklabels, freq, width=0.8)
ax2.set_ylabel("Frequency")
ax2.set_xlabel("Aggregate size")
ax2.tick_params('x', labelrotation=60)
ax2.tick_params("x", labelrotation=60)

plt.tight_layout()
plt.show()

Expand Down Expand Up @@ -254,6 +268,12 @@ def aggregates_clustering(
default=False,
help="save snapshots with ResID = cluster number to be colored by ResID in VMD (saved in ./colored_pdbs)",
)
parser.add_argument(
"--save-solvent",
action="store_true",
default=False,
help="when using --save-colored-snap, also save solvent beads in the .pdb",
)
parser.add_argument(
"--plot-dendrograms",
action="store_true",
Expand All @@ -269,6 +289,11 @@ def aggregates_clustering(

args = parser.parse_args()

if os.path.splitext(args.h5md_file)[1] != ".h5md":
raise AssertionError(
"Trajectory extension should be .h5md. If you are using .H5 please rename it."
)

if args.linkage_method not in [
"single",
"complete",
Expand Down Expand Up @@ -306,4 +331,5 @@ def aggregates_clustering(
args.save_colored_snap,
args.plot_dendrograms,
args.traj_in_memory,
args.save_solvent
)

0 comments on commit 66bb6a2

Please sign in to comment.