Skip to content

Commit

Permalink
metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed May 3, 2024
1 parent 3e4ffd0 commit c019c94
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/beignet/center_of_mass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch


def compute_center_of_mass(traj):
"""Compute the center of mass for each frame.
Parameters
----------
traj : Trajectory
Trajectory to compute center of mass for
Returns
-------
com : torch.Tensor, shape=(n_frames, 3)
Coordinates of the center of mass for each frame
"""

com = torch.empty((traj.n_frames, 3))

masses = torch.tensor([a.element.mass for a in traj.top.atoms])
masses /= masses.sum()

xyz = traj.xyz

for i, x in enumerate(xyz):
com[i, :] = torch.tensordot(masses, x.double().t(), dims=0)

return com
55 changes: 55 additions & 0 deletions src/beignet/gyration_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch


def _compute_center_of_geometry(traj):
"""Compute the center of geometry for each frame.
Parameters
----------
traj : Trajectory
Trajectory to compute center of geometry for.
Returns
-------
centers : torch.Tensor, shape=(n_frames, 3)
Coordinates of the center of geometry for each frame.
"""

centers = torch.zeros((traj.n_frames, 3))

for i, x in enumerate(traj.xyz):
centers[i, :] = torch.mean(x.double().t(), dim=1)

return centers


def gyration_tensor(traj):
"""Compute the gyration tensor of a trajectory.
For every frame,
.. math::
S_{xy} = \sum_{i_atoms} r^{i}_x r^{i}_y
Parameters
----------
traj : Trajectory
Trajectory to compute gyration tensor of.
Returns
-------
S_xy: torch.Tensor, shape=(traj.n_frames, 3, 3), dtype=float64
Gyration tensors for each frame.
References
----------
.. [1] https://isg.nist.gov/deepzoomweb/measurement3Ddata_help#shape-metrics-formulas
"""
center_of_geom = torch.unsqueeze(_compute_center_of_geometry(traj), dim=1)

xyz = traj.xyz - center_of_geom

return torch.einsum('...ji,...jk->...ik', xyz, xyz) / traj.n_atoms
60 changes: 60 additions & 0 deletions src/beignet/rmsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
from scipy.spatial.transform import Rotation as R


# TODO (isaacsoh) parallelize and speed up, eliminate 3-D requirement
def _rmsd(traj1, traj2):
"""
Parameters
----------
traj1 : Trajectory
For each conformation in this trajectory, compute the RMSD to
a particular 'reference' conformation in another trajectory.
traj2 : Trajectory
The reference conformation to measure distances
to.
Returns
-------
rmsd_result : torch.Tensor
The rmsd calculation of two trajectories.
"""

assert traj1.shape == traj2.shape, "Input tensors must have the same shape"
assert traj1.dim() == 3, "Input tensors must be 3-D (num_frames, num_atoms, 3)"

num_frames = traj1.shape[0] # Number of frames

# Center the trajectories
traj1 = traj1 - traj1.mean(dim=1, keepdim=True)
traj2 = traj2 - traj2.mean(dim=1, keepdim=True)

# Initialization of the resulting RMSD tensor
rmsd_result = torch.zeros(num_frames).double()

for i in range(num_frames):
# For each configuration compute the rotation matrix minimizing RMSD using SVD
u, s, v = torch.svd(torch.mm(traj1[i].t(), traj2[i]))

# Determinat of u * v
d = (u * v).det().item() < 0.0

if d:
s[-1] = s[-1] * (-1)
u[:, -1] = u[:, -1] * (-1)

# Optimal rotation matrix
rot_matrix = torch.mm(v, u.t())

test = (R.from_matrix(rot_matrix)).as_matrix()

assert torch.allclose(torch.from_numpy(test), rot_matrix, rtol=1e-03, atol=1e-04)

# Calculate RMSD and append to resulting tensor
traj2[i] = torch.mm(traj2[i], rot_matrix)

rmsd_result[i] = torch.sqrt(
torch.sum((traj1[i] - traj2[i]) ** 2) / traj1.shape[1])

return rmsd_result

0 comments on commit c019c94

Please sign in to comment.