From 688bb952610c834087d905607e3d02f7a02b4899 Mon Sep 17 00:00:00 2001
From: "Alex M. Maldonado"
Date: Sun, 19 May 2024 18:37:00 -0400
Subject: [PATCH] refactor: unify pocket computations

---
 subpex/pcoord/jaccard.py |  47 ----------
 subpex/pocket/compute.py | 186 +++++++++++++++++++++++++++++++++++++++
 subpex/pocket/props.py   |  32 -------
 3 files changed, 186 insertions(+), 79 deletions(-)
 delete mode 100644 subpex/pcoord/jaccard.py
 create mode 100644 subpex/pocket/compute.py
 delete mode 100644 subpex/pocket/props.py

diff --git a/subpex/pcoord/jaccard.py b/subpex/pcoord/jaccard.py
deleted file mode 100644
index 1ac2c0b..0000000
--- a/subpex/pcoord/jaccard.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from collections.abc import Sequence
-
-import scipy as sp
-
-from ..configs import SubpexConfig
-
-
-def get_jaccard_distance(
-    fop_ref: Sequence[Sequence[float]],
-    fop_segment: Sequence[Sequence[float]],
-    subpex_config: SubpexConfig,
-) -> float:
-    """Calculates the Jaccard distance between the points in `fop_ref` and
-    the `fop_segment`. Uses the distance between points to calculate the intersection.
-
-    Args:
-        fop_ref: Reference FOP.
-        fop_segment: Segment FOP.
-        subpex_config: The subpex context manager.
-
-    Returns:
-        Jaccard distance.
-    """
-    # sometimes no points are present in the FOP.
-    if len(fop_segment) == 0:
-        return 1.0
-
-    # Obtaining the trees for both field of points
-    tree_ref = sp.spatial.cKDTree(fop_ref)
-    tree_segment = sp.spatial.cKDTree(fop_segment)
-
-    # Obtain the points that are at less than resolution/2.5 (aka have the same
-    # coordinates)
-    clash_indices = tree_ref.query_ball_tree(
-        tree_segment, subpex_config.pocket_resolution / 2.5, p=2, eps=0
-    )
-
-    # Count the points that intersect and convert to float
-    intersection = float(len([x for x in clash_indices if x]))
-
-    # Obtain the union of both FOP
-    union = float(len(fop_ref) + len(fop_segment) - intersection)
-
-    # Calculate Jaccard distance
-    jaccard = 1.0 - intersection / union
-
-    return jaccard
diff --git a/subpex/pocket/compute.py b/subpex/pocket/compute.py
new file mode 100644
index 0000000..360d19e
--- /dev/null
+++ b/subpex/pocket/compute.py
@@ -0,0 +1,186 @@
+"""Compute properties of pockets."""
+
+from typing import Any
+
+from collections.abc import Sequence
+
+import MDAnalysis as mda
+import numpy as np
+import numpy.typing as npt
+import scipy as sp
+
+from ..configs import SubpexConfig
+from ..fop.compute import get_fop_inputs
+from ..pocket.detect import get_fop_pocket
+from ..utils.spatial import calculate_distance_two_points, get_centroid, get_rmsd
+
+
+def get_pocket_rmsd(
+    atoms_ref: npt.NDArray[np.float64], atoms_frame: npt.NDArray[np.float64]
+) -> float:
+    """Computes the root-mean-square deviation (RMSD) between the reference and
+    frame coordinates by delegating to `get_rmsd`.
+
+    Args:
+        atoms_ref: Atomic coordinates of reference structure.
+        atoms_frame: Atomic coordinates of current frame.
+
+    Returns:
+        RMSD of atoms captured by the pocket.
+    """
+    rmsd = get_rmsd(atoms_ref, atoms_frame)
+    return float(rmsd)
+
+
+def get_pocket_rmsd_convenience(
+    atoms_frame: mda.AtomGroup,
+    subpex_config: SubpexConfig,
+    atoms_ref: mda.AtomGroup | None = None,
+    *args: Any,
+    **kwargs: Any
+) -> float:
+    """A convenience wrapper around `get_pocket_rmsd` using a simulation frame.
+
+    Args:
+        atoms_frame: Trajectory frame to analyze.
+        subpex_config: SubPEx configuration. Not used by this function.
+        atoms_ref: Reference atoms to compare against. Required despite the
+            `None` default.
+
+    Returns:
+        RMSD between `atoms_ref` and `atoms_frame`.
+
+    Raises:
+        ValueError: If `atoms_ref` is `None`.
+    """
+    if atoms_ref is None:
+        raise ValueError("atoms_ref cannot be None")
+    # NOTE(review): these are annotated as `mda.AtomGroup`, but `get_pocket_rmsd`
+    # is annotated for coordinate arrays; confirm what `get_rmsd` accepts.
+    rmsd = get_pocket_rmsd(atoms_ref, atoms_frame)
+    return rmsd
+
+
+def get_pocket_rog(
+    fop_pocket: Sequence[Sequence[float]], *args: Any, **kwargs: Any
+) -> float:
+    """Takes the xyz coordinates of a field of points and calculates the radius
+    of gyration. It assumes a mass of one for all the points.
+
+    Args:
+        fop_pocket: The field of points defining the pocket shape.
+
+    Returns:
+        Radius of gyration of the pocket, or -9999 if `fop_pocket` is empty.
+    """
+    mass = len(fop_pocket)
+
+    if mass == 0:
+        # Sometimes there are no points.
+        return -9999
+
+    centroid = get_centroid(fop_pocket)
+    gyration_radius = 0.0
+    for i in fop_pocket:
+        gyration_radius += (calculate_distance_two_points(i, centroid)) ** 2
+
+    return float(np.sqrt(gyration_radius / mass))
+
+
+def get_pocket_rog_convenience(
+    atoms_frame: mda.AtomGroup,
+    subpex_config: SubpexConfig,
+    atoms_ref: mda.AtomGroup | None = None,
+    *args: Any,
+    **kwargs: Any
+) -> float:
+    """A convenience wrapper around `get_pocket_rog` using a simulation frame.
+
+    Args:
+        atoms_frame: Trajectory frame to analyze.
+        subpex_config: SubPEx configuration.
+        atoms_ref: Not used by this function.
+
+    Returns:
+        Radius of gyration of the pocket field of points for this frame.
+    """
+    inputs = get_fop_inputs(
+        atoms_frame, subpex_config.pocket.selection_str, subpex_config, *args, **kwargs
+    )
+    # `get_pocket_rog` takes the field of points as `fop_pocket`; its `**kwargs`
+    # absorbs the remaining entries of `inputs`.
+    inputs["fop_pocket"] = get_fop_pocket(**inputs)
+    return get_pocket_rog(**inputs)
+
+
+def get_pocket_jaccard(
+    fop_ref: Sequence[Sequence[float]],
+    fop_frame: Sequence[Sequence[float]],
+    resolution: float,
+) -> float:
+    """Calculates the Jaccard distance between the points in `fop_ref` and
+    the `fop_frame`. Uses the distance between points to calculate the intersection.
+
+    Args:
+        fop_ref: Reference FOP.
+        fop_frame: Segment FOP.
+        resolution: Field of points resolution.
+
+    Returns:
+        Jaccard distance.
+    """
+    # sometimes no points are present in the FOP.
+    if len(fop_frame) == 0:
+        return 1.0
+
+    # Obtaining the trees for both field of points
+    tree_ref = sp.spatial.cKDTree(fop_ref)
+    tree_segment = sp.spatial.cKDTree(fop_frame)
+
+    # Obtain the points that are at less than resolution/2.5 (aka have the same
+    # coordinates)
+    clash_indices = tree_ref.query_ball_tree(tree_segment, resolution / 2.5, p=2, eps=0)
+
+    # Count the points that intersect and convert to float
+    intersection = float(len([x for x in clash_indices if x]))
+
+    # Obtain the union of both FOP
+    union = float(len(fop_ref) + len(fop_frame) - intersection)
+
+    # Calculate Jaccard distance
+    jaccard = 1.0 - intersection / union
+
+    return jaccard
+
+
+def get_pocket_jaccard_convenience(
+    atoms_frame: mda.AtomGroup,
+    subpex_config: SubpexConfig,
+    atoms_ref: mda.AtomGroup | None = None,
+    fop_ref: Sequence[Sequence[float]] | None = None,
+    *args: Any,
+    **kwargs: Any
+) -> float:
+    """A convenience wrapper around `get_pocket_jaccard` using a simulation frame.
+
+    Args:
+        atoms_frame: Trajectory frame to analyze.
+        subpex_config: SubPEx configuration.
+        atoms_ref: Not used by this function.
+        fop_ref: Reference FOP. Required despite the `None` default.
+
+    Returns:
+        Jaccard distance between `fop_ref` and the frame's pocket FOP.
+
+    Raises:
+        ValueError: If `fop_ref` is `None`.
+    """
+    if fop_ref is None:
+        raise ValueError("fop_ref cannot be None")
+    inputs = get_fop_inputs(
+        atoms_frame, subpex_config.pocket.selection_str, subpex_config, *args, **kwargs
+    )
+    inputs["fop_frame"] = get_fop_pocket(**inputs)
+    # NOTE(review): `get_pocket_jaccard` accepts only `fop_ref`, `fop_frame`, and
+    # `resolution`; any other key in `inputs`, or a missing `resolution`, raises
+    # TypeError here. Confirm against what `get_fop_inputs` returns.
+    return get_pocket_jaccard(fop_ref=fop_ref, **inputs)
diff --git a/subpex/pocket/props.py b/subpex/pocket/props.py
deleted file mode 100644
index 51399df..0000000
--- a/subpex/pocket/props.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""Compute properties of pockets."""
-
-from collections.abc import Sequence
-
-import numpy as np
-
-from ..utils.spatial import calculate_distance_two_points, get_centroid
-
-
-def get_pocket_rog(fop_pocket: Sequence[Sequence[float]]) -> float:
-    """Takes the xyz coordinates of a field of points and calculates the radius
-    of gyration. It assumes a mass of one for all the points.
-
-    Args:
-        fop_pocket: the field of points defining the pocket shape.
-
-    Returns:
-        radius of gyration of the pocket.
-    """
-
-    mass = len(fop_pocket)
-
-    if mass == 0:
-        # Sometimes there are no points.
-        return -9999
-
-    centroid = get_centroid(fop_pocket)
-    gyration_radius = 0.0
-    for i in fop_pocket:
-        gyration_radius += (calculate_distance_two_points(i, centroid)) ** 2
-
-    return float(np.sqrt(gyration_radius / mass))