Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes in utils scripts #228

Merged
merged 22 commits into from
Sep 6, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add name selection in center_group.py
  • Loading branch information
Manuel Carrer committed Feb 2, 2024
commit 91d887c6e45171a2f9eabeacd93dc168efce00a4
190 changes: 157 additions & 33 deletions utils/center_group.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import argparse
import os
import re
import sys

import h5py
import numpy as np
import re


# based on https://stackoverflow.com/a/6512463/3254658
Expand All @@ -22,6 +24,19 @@ def parse_bead_list(string):
)
return list(range(start, end + 1))

def parse_name_list(selection, h5md_file):
f_in = h5py.File(h5md_file, "r")
names = f_in["parameters/vmd_structure/name"][:]
species = f_in["particles/all/species"][:]
name_to_species = {n: s for s, n in enumerate(names)}

name_list = np.empty(0, dtype=np.int32)
for name in selection:
name_list = np.append(name_list, np.where(species == name_to_species[np.bytes_(name)]))

f_in.close()
return name_list


def get_centers(positions, box, nrefs=3):
centers = np.empty((0, positions.shape[2]))
Expand Down Expand Up @@ -54,23 +69,7 @@ def get_centers(positions, box, nrefs=3):
return centers


def center_trajectory(
h5md_file, bead_list, nrefbeads, overwrite=False, out_path=None, center_last=False
):
if out_path is None:
out_path = os.path.join(
os.path.abspath(os.path.dirname(h5md_file)),
os.path.splitext(os.path.split(h5md_file)[-1])[0]
+ "_new"
+ os.path.splitext(os.path.split(h5md_file)[-1])[1],
)
if os.path.exists(out_path) and not overwrite:
error_str = (
f"The specified output file {out_path} already exists. "
f'use overwrite=True ("-f" flag) to overwrite.'
)
raise FileExistsError(error_str)

def center_trajectory_mic(h5md_file, bead_list, nrefbeads, out_path, center_last=False):
f_in = h5py.File(h5md_file, "r")
f_out = h5py.File(out_path, "w")

Expand All @@ -81,7 +80,6 @@ def center_trajectory(
f_in.copy(k, f_out)

box_size = f_in["particles/all/box/edges/value"][:]

beads_pos = f_in["particles/all/position/value"][:][:, bead_list, :]
centers = get_centers(beads_pos, box_size, nrefbeads)

Expand All @@ -108,6 +106,68 @@ def center_trajectory(
f_in.close()
f_out.close()

def center_of_mass(pos, n, box):
p_mapped = 2 * np.pi * pos / box
cos_p_mapped = np.cos(p_mapped)
sin_p_mapped = np.sin(p_mapped)

cos_average = np.sum(cos_p_mapped) / n
sin_average = np.sum(sin_p_mapped) / n

theta = np.arctan2(-sin_average, -cos_average) + np.pi
return box * theta / (2 * np.pi)


def get_centers_com(positions, box_size, axis):
frames = positions.shape[0]
n = positions.shape[1]
centers = np.zeros((frames, 3))
for frame in range(frames):
centers[frame, axis] = center_of_mass(positions[frame, :, axis], n, box_size[frame, axis])
return centers


def center_trajectory_com(h5md_file, bead_list, axis, out_path):
f_in = h5py.File(h5md_file, "r")
f_out = h5py.File(out_path, "w")

for k, v in f_in.attrs.items():
f_out.attrs[k] = v

for k in f_in.keys():
f_in.copy(k, f_out)
f_in.close()

box_size = f_out["particles/all/box/edges/value"][:]
beads_pos = f_out["particles/all/position/value"][:][:, bead_list, :]

box_diag = np.diag(box_size[0])
for frame in range(1, box_size.shape[0]):
box_diag = np.vstack((box_diag, np.diag(box_size[frame])))

centers = get_centers_com(beads_pos, box_diag, axis)

mask = np.eye(1, 3, k=axis)
box_translate = 0.5 * mask * box_diag

translate = box_translate - centers
tpos = f_out["particles/all/position/value"] + np.repeat(
translate[:, np.newaxis, :],
f_out["particles/all/position/value"].shape[1],
axis=1,
)

tpos = np.mod(
tpos[:, :, :], np.repeat(box_diag[:, np.newaxis, :], tpos.shape[1], axis=1)
)

f_out["particles/all/position/value"][:] = tpos - np.repeat(
box_translate[:, np.newaxis, :],
f_out["particles/all/position/value"].shape[1],
axis=1,
)
f_out.close()


if __name__ == "__main__":
description = (
Expand All @@ -120,9 +180,35 @@ def center_trajectory(
"--beads",
type=parse_bead_list,
nargs="+",
required=True,
default=None,
help="bead list to center (e.g.: 1-100 102-150)",
)
parser.add_argument(
"-n",
"--names",
type=str,
nargs="+",
default=None,
help="Names of bead to center (e.g.: C1 N G2)",
)
parser.add_argument(
"-m",
"--method",
choices=["COM", "MIC"],
type=str,
default=None,
help="Specify the centering method. Available methods:\n"
"COM: center the box around the center of mass of the given groups along a direction (axis)\n"
"MIC: center the box around a centroid of the group that assures the minimal image convention",
)
parser.add_argument(
"--axis",
type=int,
choices=[0, 1, 2],
default=2,
required='COM' in sys.argv,
help="Direction along which to calculate the center of mass."
)
parser.add_argument(
"-o",
"--out",
Expand Down Expand Up @@ -155,17 +241,55 @@ def center_trajectory(
)
args = parser.parse_args()

bead_list = []
for interval in args.beads:
bead_list += interval
if args.beads is None and args.names is None:
error_str = "Either the 'beads' or 'names' variable must be provided."
raise ValueError(error_str)

if args.out_path is None:
args.out_path = os.path.join(
os.path.abspath(os.path.dirname(args.h5md_file)),
os.path.splitext(os.path.split(args.h5md_file)[-1])[0]
+ "_new"
+ os.path.splitext(os.path.split(args.h5md_file)[-1])[1],
)
if os.path.exists(args.out_path) and not args.force:
error_str = (
f"The specified output file {args.out_path} already exists. "
f'use overwrite=True ("-f" flag) to overwrite.'
)
raise FileExistsError(error_str)

if args.beads is not None:
bead_list = []
for interval in args.beads:
bead_list += interval

bead_list = np.array(sorted(bead_list)) - 1
bead_list = np.array(sorted(bead_list)) - 1

if args.names is not None:
name_list = parse_name_list(args.names, args.h5md_file)

if args.beads is not None and args.names is not None:
atom_list = np.intersect1d(bead_list, name_list)
elif args.beads is not None:
atom_list = bead_list
else:
atom_list = name_list

if args.method == "COM":
center_trajectory_com(
args.h5md_file,
atom_list,
args.axis,
args.out_path,
)

if args.method == "MIC":
center_trajectory_mic(
args.h5md_file,
atom_list,
nrefbeads=args.nrefs,
out_path=args.out_path,
center_last=args.center_last,
)

center_trajectory(
args.h5md_file,
bead_list,
nrefbeads=args.nrefs,
overwrite=args.force,
out_path=args.out_path,
center_last=args.center_last,
)