diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000..bbdd2e8098
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+  - hooks:
+      - id: "check-toml"
+      - id: "check-yaml"
+    repo: "https://github.com/pre-commit/pre-commit-hooks"
+    rev: "v4.6.0"
+  - hooks:
+      - args:
+          - "--fix"
+        id: "ruff"
+      - id: "ruff-format"
+    repo: "https://github.com/astral-sh/ruff-pre-commit"
+    rev: "v0.3.7"
diff --git a/pyproject.toml b/pyproject.toml
index 7f9d47afbf..2a13cfd74e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,8 +14,30 @@ name = "beignet"
 readme = "README.md"
 requires-python = ">=3.10"
 
-[tool.ruff.format]
-docstring-code-format = true
+[tool.ruff]
+exclude = [
+    "./src/beignet/constants/_adjacent_residue_phi_cosine.py",
+    "./src/beignet/constants/_adjacent_residue_psi_cosine.py",
+    "./src/beignet/constants/_amino_acid_1.py",
+    "./src/beignet/constants/_amino_acid_1_to_amino_acid_3.py",
+    "./src/beignet/constants/_amino_acid_3.py",
+    "./src/beignet/constants/_amino_acid_3_to_amino_acid_1.py",
+    "./src/beignet/constants/_amino_acid_3_to_atom_14.py",
+]
+
+[tool.ruff.lint]
+select = [
+    "B",  # FLAKE8-BUGBEAR
+    "E",  # PYCODESTYLE ERRORS
+    "F",  # PYFLAKES
+    "I",  # ISORT
+    "W",  # PYCODESTYLE WARNINGS
+]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = [
+    "F401",  # MODULE IMPORTED BUT UNUSED
+]
 
 [tool.setuptools_scm]
 local_scheme = "no-local-version"
diff --git a/src/beignet/__init__.py b/src/beignet/__init__.py
index 8e489b9a12..d0279879ba 100644
--- a/src/beignet/__init__.py
+++ b/src/beignet/__init__.py
@@ -5,3 +5,15 @@
     __version__ = importlib.metadata.version("beignet")
 except PackageNotFoundError:
     __version__ = None
+
+from ._apply_rotation_matrix import apply_rotation_matrix
+from ._apply_transform import apply_transform
+from ._invert_rotation_matrix import invert_rotation_matrix
+from ._invert_transform import invert_transform
+
+__all__ = [
+    "apply_rotation_matrix",
+    "apply_transform",
+    "invert_rotation_matrix",
+    "invert_transform",
+]
diff --git a/src/beignet/_apply_rotation_matrix.py b/src/beignet/_apply_rotation_matrix.py
new file mode 100644
index 0000000000..fcac443401
--- /dev/null
+++ b/src/beignet/_apply_rotation_matrix.py
@@ -0,0 +1,47 @@
+import torch
+from torch import Tensor
+
+
+def apply_rotation_matrix(
+    input: Tensor,
+    rotation: Tensor,
+    inverse: bool | None = False,
+) -> Tensor:
+    r"""
+    Rotates vectors in three-dimensional space using rotation matrices.
+
+    Note
+    ----
+    This function interprets the rotation of the original frame to the final
+    frame as either a projection, where it maps the components of vectors
+    from the final frame to the original frame, or as a physical rotation of
+    the vectors, which remain fixed to the original frame as it rotates. In
+    either interpretation, the vector components are expressed in the
+    original frame's coordinates both before and after the rotation.
+
+    Parameters
+    ----------
+    input : Tensor, shape (..., 3)
+        Each vector represents a vector in three-dimensional space. The
+        batch dimensions of the vectors and the rotation matrices must
+        follow standard broadcasting rules: corresponding dimensions must be
+        equal or one of them must be one.
+
+    rotation : Tensor, shape (..., 3, 3)
+        Rotation matrices.
+
+    inverse : bool, optional
+        If `True`, the inverse of the rotation matrices is applied to the
+        input vectors. Default, `False`.
+
+    Returns
+    -------
+    rotated_vectors : Tensor, shape (..., 3)
+        Rotated vectors.
+    """
+    if inverse:
+        output = torch.einsum("...kj, ...k -> ...j", rotation, input)
+    else:
+        output = torch.einsum("...jk, ...k -> ...j", rotation, input)
+
+    return output
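Reviewer note: a minimal usage sketch of `apply_rotation_matrix`. The rotation below is illustrative and not part of this patch.

```python
import torch

import beignet

# A 90-degree rotation about the z-axis.
rotation = torch.tensor(
    [
        [0.0, -1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
)

vector = torch.tensor([1.0, 0.0, 0.0])

rotated = beignet.apply_rotation_matrix(vector, rotation)
# rotated is approximately [0, 1, 0]

restored = beignet.apply_rotation_matrix(rotated, rotation, inverse=True)
# restored is approximately [1, 0, 0]
```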
+ """ + if inverse: + output = torch.einsum("ikj, ik -> ij", rotation, input) + else: + output = torch.einsum("ijk, ik -> ij", rotation, input) + + return output diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py new file mode 100644 index 0000000000..5037852e50 --- /dev/null +++ b/src/beignet/_apply_transform.py @@ -0,0 +1,54 @@ +from torch import Tensor + +from beignet.operators import apply_rotation_matrix + + +def apply_transform( + input: Tensor, + transform: (Tensor, Tensor), + inverse: bool | None = False, +) -> Tensor: + r""" + Applies three-dimensional transforms to vectors. + + Note + ---- + This function interprets the rotation of the original frame to the final + frame as either a projection, where it maps the components of vectors from + the final frame to the original frame, or as a physical rotation, + integrating the vectors into the original frame during the rotation + process. Consequently, the vector components are maintained in the original + frame’s perspective both before and after the rotation. + + Parameters + ---------- + input : Tensor + Each vector represents a vector in three-dimensional space. The number + of rotation matrices, number of translation vectors, and number of + input vectors must follow standard broadcasting rules: either one of + them equals unity or they both equal each other. + + transform : (Tensor, Tensor) + Transforms represented as a pair of rotation matrices and translation + vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and + the translations must have the shape $(\ldots, 3)$. + + inverse : bool, optional + If `True`, applies the inverse transformation (i.e., inverse rotation + and negated translation) to the input vectors. Default, `False`. + + Returns + ------- + output : Tensor + Rotated and translated vectors. + """ + rotation, translation = transform + + output = apply_rotation_matrix(input, rotation, inverse=inverse) + + if inverse: + output = output - translation + else: + output = output + translation + + return output diff --git a/src/beignet/_invert_rotation_matrix.py b/src/beignet/_invert_rotation_matrix.py new file mode 100644 index 0000000000..dbefb18fa6 --- /dev/null +++ b/src/beignet/_invert_rotation_matrix.py @@ -0,0 +1,19 @@ +import torch +from torch import Tensor + + +def invert_rotation_matrix(input: Tensor) -> Tensor: + r""" + Invert rotation matrices. + + Parameters + ---------- + input : Tensor, shape=(..., 3, 3) + Rotation matrices. + + Returns + ------- + output : Tensor, shape=(..., 3, 3) + Inverted rotation matrices. + """ + return torch.transpose(input, -2, -1) diff --git a/src/beignet/_invert_transform.py b/src/beignet/_invert_transform.py new file mode 100644 index 0000000000..27a91d1ce4 --- /dev/null +++ b/src/beignet/_invert_transform.py @@ -0,0 +1,31 @@ +import torch +from torch import Tensor + +import beignet.operators + + +def invert_transform(input: (Tensor, Tensor)) -> (Tensor, Tensor): + r""" + Invert transforms. + + Parameters + ---------- + input : (Tensor, Tensor) + Transforms represented as a pair of rotation matrices and translation + vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and + the translations must have the shape $(\ldots, 3)$. + + Returns + ------- + output : (Tensor, Tensor) + Inverted transforms represented as a pair of rotation matrices and + translation vectors. The rotation matrices have the shape + $(\ldots, 3, 3)$ and the translations have the shape $(\ldots, 3)$. 
+ """ + rotation, translation = input + + rotation = beignet.operators.invert_rotation_matrix(rotation) + + return rotation, -rotation @ torch.squeeze( + torch.unsqueeze(translation, dim=-1), dim=-1 + ) diff --git a/src/beignet/constants/__init__.py b/src/beignet/constants/__init__.py new file mode 100644 index 0000000000..84e6c5f256 --- /dev/null +++ b/src/beignet/constants/__init__.py @@ -0,0 +1,17 @@ +from ._adjacent_residue_phi_cosine import ADJACENT_RESIDUE_PHI_COSINE +from ._adjacent_residue_psi_cosine import ADJACENT_RESIDUE_PSI_COSINE +from ._amino_acid_1 import AMINO_ACID_1 +from ._amino_acid_1_to_amino_acid_3 import AMINO_ACID_1_TO_AMINO_ACID_3 +from ._amino_acid_3 import AMINO_ACID_3 +from ._amino_acid_3_to_amino_acid_1 import AMINO_ACID_3_TO_AMINO_ACID_1 +from ._amino_acid_3_to_atom_14 import AMINO_ACID_3_TO_ATOM_14 + +__all__ = [ + "ADJACENT_RESIDUE_PHI_COSINE", + "ADJACENT_RESIDUE_PSI_COSINE", + "AMINO_ACID_1", + "AMINO_ACID_1_TO_AMINO_ACID_3", + "AMINO_ACID_3", + "AMINO_ACID_3_TO_AMINO_ACID_1", + "AMINO_ACID_3_TO_ATOM_14", +] diff --git a/src/beignet/constants/_adjacent_residue_phi_cosine.py b/src/beignet/constants/_adjacent_residue_phi_cosine.py new file mode 100644 index 0000000000..1ffa00f3be --- /dev/null +++ b/src/beignet/constants/_adjacent_residue_phi_cosine.py @@ -0,0 +1 @@ +ADJACENT_RESIDUE_PHI_COSINE = [-0.5203, 0.0353] diff --git a/src/beignet/constants/_adjacent_residue_psi_cosine.py b/src/beignet/constants/_adjacent_residue_psi_cosine.py new file mode 100644 index 0000000000..ff0215d2c4 --- /dev/null +++ b/src/beignet/constants/_adjacent_residue_psi_cosine.py @@ -0,0 +1 @@ +ADJACENT_RESIDUE_PSI_COSINE = [-0.4473, 0.0311] diff --git a/src/beignet/constants/_amino_acid_1.py b/src/beignet/constants/_amino_acid_1.py new file mode 100644 index 0000000000..48878eef02 --- /dev/null +++ b/src/beignet/constants/_amino_acid_1.py @@ -0,0 +1,22 @@ +AMINO_ACID_1 = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", +] diff --git a/src/beignet/constants/_amino_acid_1_to_amino_acid_3.py b/src/beignet/constants/_amino_acid_1_to_amino_acid_3.py new file mode 100644 index 0000000000..17c82163b6 --- /dev/null +++ b/src/beignet/constants/_amino_acid_1_to_amino_acid_3.py @@ -0,0 +1,22 @@ +AMINO_ACID_1_TO_AMINO_ACID_3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} diff --git a/src/beignet/constants/_amino_acid_3.py b/src/beignet/constants/_amino_acid_3.py new file mode 100644 index 0000000000..45b7de8f70 --- /dev/null +++ b/src/beignet/constants/_amino_acid_3.py @@ -0,0 +1,22 @@ +AMINO_ACID_3 = [ + "ALA", + "ARG", + "ASN", + "ASP", + "CYS", + "GLN", + "GLU", + "GLY", + "HIS", + "ILE", + "LEU", + "LYS", + "MET", + "PHE", + "PRO", + "SER", + "THR", + "TRP", + "TYR", + "VAL", +] diff --git a/src/beignet/constants/_amino_acid_3_to_amino_acid_1.py b/src/beignet/constants/_amino_acid_3_to_amino_acid_1.py new file mode 100644 index 0000000000..0e6c7ff51b --- /dev/null +++ b/src/beignet/constants/_amino_acid_3_to_amino_acid_1.py @@ -0,0 +1,22 @@ +AMINO_ACID_3_TO_AMINO_ACID_1 = { + "ALA": "A", + "ARG": "R", + "ASN": "N", + "ASP": "D", + "CYS": "C", + "GLN": "Q", + "GLU": "E", + "GLY": "G", + "HIS": "H", + "ILE": "I", + "LEU": "L", + "LYS": "K", + "MET": "M", + "PHE": 
"F", + "PRO": "P", + "SER": "S", + "THR": "T", + "TRP": "W", + "TYR": "Y", + "VAL": "V", +} diff --git a/src/beignet/constants/_amino_acid_3_to_atom_14.py b/src/beignet/constants/_amino_acid_3_to_atom_14.py new file mode 100644 index 0000000000..929d009c83 --- /dev/null +++ b/src/beignet/constants/_amino_acid_3_to_atom_14.py @@ -0,0 +1,23 @@ +AMINO_ACID_3_TO_ATOM_14 = { + "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", "" ], + "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", "" ], + "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", "" ], + "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", "" ], + "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", "" ], + "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", "" ], + "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", "" ], + "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", "" ], + "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", "" ], + "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", "" ], + "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", "" ], + "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", "" ], + "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", "" ], + "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", "" ], + "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", "" ], + "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", "" ], + "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", "" ], + "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], + "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", "" ], + "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", "" ], + "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", "" ], +} diff --git a/src/beignet/nn/__init__.py b/src/beignet/nn/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/beignet/nn/functional/__init__.py b/src/beignet/nn/functional/__init__.py new file mode 100644 index 0000000000..0e71fe47de --- /dev/null +++ b/src/beignet/nn/functional/__init__.py @@ -0,0 +1,7 @@ +from ._angle_norm_loss import angle_norm_loss +from ._distogram_loss import distogram_loss +from ._global_distance_test import global_distance_test +from ._per_residue_local_distance_difference_test import ( + per_residue_local_distance_difference_test, +) +from ._torsion_angle_loss import torsion_angle_loss diff --git a/src/beignet/nn/functional/_alphafold_loss.py b/src/beignet/nn/functional/_alphafold_loss.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/beignet/nn/functional/_angle_norm_loss.py b/src/beignet/nn/functional/_angle_norm_loss.py new file mode 100644 index 0000000000..39be677c78 --- /dev/null +++ b/src/beignet/nn/functional/_angle_norm_loss.py @@ -0,0 +1,25 @@ +from typing import Literal + +import torch +from torch import Tensor + + +def angle_norm_loss( + input, + reduction: Literal["mean", "sum"] | None = "mean", +) -> Tensor: + loss = torch.abs(torch.norm(input, dim=-1) - 1) + + match reduction: + case "mean": + loss = torch.mean( + loss, + dim=[-1, -2], + ) + case "sum": + loss = torch.sum( + loss, + dim=[-1, -2], + ) + + return 2.0 * loss diff 
--git a/src/beignet/nn/functional/_bond_angle_violation_loss.py b/src/beignet/nn/functional/_bond_angle_violation_loss.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/beignet/nn/functional/_bond_length_violation_loss.py b/src/beignet/nn/functional/_bond_length_violation_loss.py new file mode 100644 index 0000000000..e841bbe59c --- /dev/null +++ b/src/beignet/nn/functional/_bond_length_violation_loss.py @@ -0,0 +1,174 @@ +import torch +from torch import Tensor + +from beignet.constants import AMINO_ACID_3 + + +# IGNORE THIS: +def f4( + dx: Tensor, + je: Tensor, + s2: Tensor, + r0: Tensor, + zc: float = 12.0, + gj: float = 12.0, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor): + pn = torch.finfo(dx.dtype).eps + + me = {k: v for v, k in enumerate([*AMINO_ACID_3, "UNK"])} + + cv = r0[..., 1:] == me["PRO"] + + ip = cv * 1.341 + p3 = cv * 0.016 + + n5 = ~cv + km = ~cv + + n5 = n5 * 1.329 + km = km * 0.014 + + n5 = n5 + ip + km = km + p3 + + ow = dx[..., :-1, 2, :] + ch = dx[..., +1:, 0, :] + fj = dx[..., :-1, 1, :] + n4 = dx[..., +1:, 1, :] + + k4 = je[..., :-1, 2] + uo = je[..., +1:, 0] + o6 = je[..., :-1, 1] + em = je[..., +1:, 1] + + r3 = fj - ow + re = ch - n4 + p6 = ow - ch + + r3 = r3**2 + re = re**2 + p6 = p6**2 + + r3 = torch.sum(r3, dim=-1) + re = torch.sum(re, dim=-1) + p6 = torch.sum(p6, dim=-1) + + r3 = r3 + pn + re = re + pn + p6 = p6 + pn + + r3 = torch.sqrt(r3) + re = torch.sqrt(re) + p6 = torch.sqrt(p6) + + mp = n5 + mp = p6 - mp + mp = mp**2 + mp = mp + pn + mp = torch.sqrt(mp) + + zu = ch - ow + zu = zu / p6[..., None] + + tn = n4 - ch + mn = fj - ow + + mn = mn / r3[..., None] + tn = tn / re[..., None] + + mn = mn * +zu + tn = tn * -zu + + mn = torch.sum(mn, dim=-1) + tn = torch.sum(tn, dim=-1) + + xa = mn + 0.4473 + xn = tn + 0.5203 + + xa = xa**2 + xn = xn**2 + + xa = xa + pn + xn = xn + pn + + xa = torch.sqrt(xa) + xn = torch.sqrt(xn) + + e1 = zc * km + f6 = zc * 0.0140 + ot = zc * 0.0353 + + e1 = mp - e1 + f6 = xa - f6 + ot = xn - ot + + e1 = torch.nn.functional.relu(e1) + f6 = torch.nn.functional.relu(f6) + ot = torch.nn.functional.relu(ot) + + qh = s2[..., :-1] == 1.0 + qh = s2[..., 1:] - qh + + x8 = k4 * uo + x8 = x8 * qh + + qg = o6 * k4 + qg = qg * uo + qg = qg * qh + + rc = k4 * uo + rc = rc * em + rc = rc * qh + + t4 = e1 * x8 + vc = f6 * qg + wx = ot * rc + + t4 = torch.sum(t4, dim=-1) + vc = torch.sum(vc, dim=-1) + wx = torch.sum(wx, dim=-1) + + ze = torch.sum(x8, dim=-1) + wf = torch.sum(qg, dim=-1) + qf = torch.sum(rc, dim=-1) + + ze = ze + pn + wf = wf + pn + qf = qf + pn + + t4 = t4 / ze + vc = vc / wf + wx = wx / qf + + ub = ot + e1 + ub = ub + f6 + + rv = torch.nn.functional.pad(ub, [0, 1]) + dk = torch.nn.functional.pad(ub, [1, 0]) + + nq = rv + dk + nq = nq * 0.5 + + ic = k4 * uo + ic = ic * em + ic = ic * qh + + x3 = [] + + for zg in [mp, xn, xa]: + mw = gj * 0.0353 + mw = zg > mw + mw = mw * ic + + x3 = [*x3, mw] + + mh = torch.stack(x3, dim=-2) + + va, _ = torch.max(mh, dim=-2) + + yn = torch.nn.functional.pad(va, [0, 1]) + gh = torch.nn.functional.pad(va, [1, 0]) + + nd = torch.maximum(yn, gh) + + return t4, vc, wx, nq, nd diff --git a/src/beignet/nn/functional/_clash_loss.py b/src/beignet/nn/functional/_clash_loss.py new file mode 100644 index 0000000000..579ebfedf5 --- /dev/null +++ b/src/beignet/nn/functional/_clash_loss.py @@ -0,0 +1,82 @@ +import torch +from torch import Tensor + + +def clash_loss( + input: Tensor, + target: (Tensor, Tensor), + mask: Tensor, + tighten=0.0, + epsilon=1e-10, +) -> (Tensor, Tensor, Tensor): + 
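Reviewer note: a shape sketch of `bond_length_violation_loss` as defined above, with random inputs; illustrative only, not part of this patch.

```python
import torch

from beignet.nn.functional._bond_length_violation_loss import (
    bond_length_violation_loss,
)

positions = torch.randn(2, 16, 14, 3)                 # 2 chains, 16 residues
mask = torch.ones(2, 16, 14)                          # every atom exists
residue_index = torch.arange(16).expand(2, 16)        # no chain breaks
residue_type = torch.zeros(2, 16, dtype=torch.int64)  # all "ALA"

c_n, ca_c_n, c_n_ca, per_residue, violations = bond_length_violation_loss(
    positions, mask, residue_index, residue_type
)

print(per_residue.shape, violations.shape)  # (2, 16), (2, 16)
```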
r""" + A one-sided flat-bottom-potential, that penalizes steric clashes: + + $$\mathcal{L}_{\text{clash}}=\sum_{i=1}^{N_{\text{non-bonded}}}\max{ + \left(\text{distance }_{\text{Van der Waals radii}}^{i}- + \tau- + \text{distance }_{\text{predicted}}^{i},0\right)},$$ + + where $N_{\text{non-bonded pairs}}$ is the number of all non-bonded atom + pairs, $\text{distance }_{\text{predicted}}^{i}$ is the distance of two + non-bonded atoms in the predicted structure, and + $\text{distance }_{\text{Van der Waals radii}}^{i}$ is the “clashing + distance” of two non-bonded atoms according to their Van der Waals radii. + The tolerance, $\tau$, $1.5\text{\r{A}}$. + + Parameters + ---------- + input : Tensor, shape=(..., N, 14, 3) + Predicted positions of atoms in global prediction frame. + + target : Tensor, shape=(..., N, 14), Tensor, shape=(..., N, 14) + Lower and upper bound on allowed distances. + + mask : Tensor, shape=(..., N, 14) + Mask denoting whether atom at positions exists for given amino acid type. + + tighten : float, optional + Extra factor to tighten loss. Default, 0.0. + + epsilon : float, optional + Small value to avoid division by zero. Default, 1e-10. + + Returns + ------- + output : Tensor, shape=(..., N, 14) + Sum of all clash losses per atom. + + mask : Tensor, shape=(..., N, 14) + Whether atom clashes with any other atom. + + clashes : Tensor, shape=(..., N) + Number of clashes per atom. + """ + distance_mask = torch.eye(14) + distance_mask = distance_mask[None] + distance_mask = 1.0 - distance_mask + shape = [*((1,) * len(mask.shape[:-2])), *distance_mask.shape] + distance_mask = torch.reshape(distance_mask, shape) + distance_mask = distance_mask * mask[..., :, :, None] + distance_mask = distance_mask * mask[..., :, None, :] + + distance = input[..., :, :, None, :] - input[..., :, None, :, :] + distance = torch.sqrt(torch.sum(distance**2, dim=-1) + epsilon) + + a, b = target + + a = torch.nn.functional.relu((a + tighten) - distance) + b = torch.nn.functional.relu(distance - (b - tighten)) + + loss = (a + b) * distance_mask + + violations = ((distance < a) | (distance > b)) * distance_mask + + return ( + torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1), + torch.maximum( + torch.max(violations, dim=-2)[0], + torch.max(violations, dim=-1)[0], + ), + torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1), + ) diff --git a/src/beignet/nn/functional/_distogram_loss.py b/src/beignet/nn/functional/_distogram_loss.py new file mode 100644 index 0000000000..dd2e72dd29 --- /dev/null +++ b/src/beignet/nn/functional/_distogram_loss.py @@ -0,0 +1,22 @@ +# ruff: noqa: E501 + +from typing import Literal + +import torch +from torch import Tensor + + +def distogram_loss(input: Tensor, target: Tensor, mask: Tensor, start: float = 2.3125, end: float = 21.6875, steps: int = 64, reduction: Literal["mean", "sum"] | None = "mean") -> Tensor: # fmt: off + target = torch.nn.functional.one_hot(torch.sum(torch.sum((target[..., None, :] - target[..., None, :, :]) ** 2.0, dim=-1, keepdim=True) > torch.linspace(start, end, steps - 1) ** 2.0, dim=-1), steps) # fmt: off + + mask = mask[..., None] * mask[..., None, :] + + output = torch.sum(torch.sum(torch.sum(torch.nn.functional.log_softmax(input, dim=-1) * target, dim=-1) * -1.0 * mask, dim=-1) / (torch.sum(mask, dim=[-1, -2]) + torch.finfo(mask.dtype).eps)[..., None], dim=-1) # fmt: off + + match reduction: + case "mean": + output = torch.mean(output) + case "sum": + output = torch.sum(output) + + return output diff --git 
diff --git a/src/beignet/nn/functional/_experimentally_resolved_loss.py b/src/beignet/nn/functional/_experimentally_resolved_loss.py
new file mode 100644
index 0000000000..bfd4e432a9
--- /dev/null
+++ b/src/beignet/nn/functional/_experimentally_resolved_loss.py
@@ -0,0 +1,58 @@
+import torch
+from torch import Tensor
+
+
+def sigmoid_cross_entropy(input: Tensor, target: Tensor) -> Tensor:
+    logits_dtype = input.dtype
+    input = input.double()
+    target = target.double()
+    log_p = torch.nn.functional.logsigmoid(input)
+    log_not_p = torch.nn.functional.logsigmoid(-input)
+    output = -target * log_p - (1.0 - target) * log_not_p
+    return output.to(dtype=logits_dtype)
+
+
+def experimentally_resolved_loss(
+    input: Tensor,
+    atom37_atom_exists: Tensor,
+    all_atom_mask: Tensor,
+    resolution: Tensor,
+    minimum_resolution: float,
+    maximum_resolution: float,
+    eps: float = 1e-8,
+) -> Tensor:
+    r"""
+    The model contains a head that predicts whether an atom is experimentally
+    resolved in a high-resolution structure. The input for this head is the
+    single representation $\{\mathrm{s}_{i}\}$ produced by the Evoformer
+    stack. The single representation is projected with a linear layer and a
+    sigmoid to atom-wise probabilities
+    $\{p^{\mathrm{experimentally\;resolved},a}_{i}\}$ with
+    $i\in\left[1,\ldots,N_{\textrm{residue}}\right]$ and
+    $a\in\textsl{S}_{\mathrm{amino\;acids}}$:
+
+    $$\mathcal{L}_{\mathrm{experimentally\;resolved}}=\mathrm{mean}_{\left(i,a\right)}\left(-y_{i}^{a}\log{\left(p_{i}^{\mathrm{experimentally\;resolved},a}\right)}-\left(1-y_{i}^{a}\right)\log{\left(1-p_{i}^{\mathrm{experimentally\;resolved},a}\right)}\right)$$
+
+    where $y_{i}^{a}\in\left\{0,1\right\}$ is the target (i.e., whether atom
+    $a$ in residue $i$ was resolved in the experiment).
+    """
+    errors = sigmoid_cross_entropy(input, all_atom_mask)
+
+    output = errors * atom37_atom_exists
+    output = torch.sum(output, dim=-1)
+
+    denominator = torch.sum(atom37_atom_exists, dim=[-1, -2])
+    denominator = torch.unsqueeze(denominator, dim=-1)
+    denominator = denominator + eps
+
+    output = output / denominator
+
+    output = torch.sum(output, dim=-1)
+
+    # Only score examples whose experimental resolution falls within the
+    # trusted range.
+    in_range = (resolution >= minimum_resolution) & (resolution <= maximum_resolution)
+
+    output = output * in_range
+
+    return torch.mean(output)
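Reviewer note: a sketch of `experimentally_resolved_loss` with random data. The module is not re-exported from `beignet.nn.functional`, so the import below refers to the private module; illustrative only, not part of this patch.

```python
import torch

from beignet.nn.functional._experimentally_resolved_loss import (
    experimentally_resolved_loss,
)

batch, residues, atoms = 2, 8, 37

logits = torch.randn(batch, residues, atoms)
atom37_atom_exists = torch.ones(batch, residues, atoms)
all_atom_mask = torch.randint(0, 2, (batch, residues, atoms)).float()
resolution = torch.tensor([1.2, 2.8])  # one resolution per example

loss = experimentally_resolved_loss(
    logits,
    atom37_atom_exists,
    all_atom_mask,
    resolution,
    minimum_resolution=0.1,
    maximum_resolution=3.0,
)
print(loss)  # scalar tensor
```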
diff --git a/src/beignet/nn/functional/_frame_aligned_point_error.py b/src/beignet/nn/functional/_frame_aligned_point_error.py
new file mode 100644
index 0000000000..6cddfa9a06
--- /dev/null
+++ b/src/beignet/nn/functional/_frame_aligned_point_error.py
@@ -0,0 +1,114 @@
+import torch
+from torch import Tensor
+
+import beignet
+
+
+def frame_aligned_point_error(
+    input: tuple[tuple[Tensor, Tensor], Tensor],
+    target: tuple[tuple[Tensor, Tensor], Tensor],
+    mask: tuple[Tensor, Tensor],
+    z: float,
+    mp: Tensor | None = None,
+    wk: float | None = None,
+) -> Tensor:
+    r"""
+    Score a set of predicted atom coordinates, $\left\{\vec{x}_{j}\right\}$,
+    under a set of predicted local frames, $\left\{T_{i}\right\}$, against
+    the corresponding target atom coordinates,
+    $\left\{\vec{x}_{i}^{\mathrm{True}}\right\}$, and target local frames,
+    $\left\{T_{i}^{\mathrm{True}}\right\}$. All atoms in all backbone and
+    side chain frames are scored.
+
+    Additionally, a cheaper version (scoring only all $C_\alpha$ atoms in
+    all backbone frames) is used as an auxiliary loss in every layer of the
+    AlphaFold structure module.
+
+    To formulate the loss, the atom position $\vec{x}_{j}$ is computed
+    relative to frame $T_{i}$, and the corresponding true atom position
+    $\vec{x}_{j}^{\mathrm{True}}$ is computed relative to the true frame
+    $T_{i}^{\mathrm{True}}$. The deviation is computed as a robust L2 norm
+    ($\epsilon$ is a small constant added to ensure that gradients are
+    numerically well behaved for small differences; its exact value does
+    not matter, as long as it is small enough).
+
+    The $N_{\mathrm{frames}} \times N_{\mathrm{atoms}}$ deviations are
+    penalized with a clamped L1 loss with a length scale,
+    $Z = 10\text{\r{A}}$, to make the loss unitless.
+
+    Parameters
+    ----------
+    input : ((Tensor, Tensor), Tensor)
+        A pair of predicted local frames, $\left\{T_{i}\right\}$, and
+        predicted atom coordinates, $\left\{\vec{x}_{j}\right\}$. A frame
+        is represented as a pair of rotation matrices and corresponding
+        translations. The atom positions must have the shape
+        $(\ldots, \text{points}, 3)$, the rotation matrices the shape
+        $(\ldots, 3, 3)$, and the translations the shape $(\ldots, 3)$.
+
+    target : ((Tensor, Tensor), Tensor)
+        A pair of target local frames,
+        $\left\{T_{i}^{\mathrm{True}}\right\}$, and target atom
+        coordinates, $\left\{\vec{x}_{i}^{\mathrm{True}}\right\}$, with the
+        same shapes as `input`.
+
+    mask : (Tensor, Tensor)
+        A pair of binary masks: one for the frames, of shape
+        $(\ldots, \text{frames})$, and one for the points, of shape
+        $(\ldots, \text{points})$.
+
+    z : float
+        Length scale by which the loss is divided.
+
+    mp : Tensor | None, optional
+        Mask of shape $(\ldots, \text{frames}, \text{points})$ used to
+        separate intra-chain from inter-chain losses.
+
+    wk : float | None, optional
+        Cutoff above which distance errors are clamped.
+
+    Returns
+    -------
+    output : Tensor
+        Frame aligned point error, averaged over frames and points.
+    """
+    transform, input = input
+
+    target_transform, target = target
+
+    epsilon = torch.finfo(input.dtype).eps
+
+    input = beignet.apply_transform(
+        input[..., None, :, :],
+        beignet.invert_transform(
+            transform,
+        ),
+    )
+
+    target = beignet.apply_transform(
+        target[..., None, :, :],
+        beignet.invert_transform(
+            target_transform,
+        ),
+    )
+
+    output = torch.sqrt(torch.sum((input - target) ** 2, dim=-1) + epsilon)
+
+    if wk is not None:
+        output = torch.clamp(output, 0, wk)
+
+    output = output / z * mask[0][..., None] * mask[1][..., None, :]
+
+    if mp is not None:
+        output = torch.sum(output * mp, dim=[-1, -2]) / (
+            torch.sum(mask[0][..., None] * mask[1][..., None, :] * mp, dim=[-2, -1])
+            + epsilon
+        )
+    else:
+        output = torch.sum(
+            torch.sum(output, dim=-1)
+            / (torch.sum(mask[0], dim=-1)[..., None] + epsilon),
+            dim=-1,
+        ) / (torch.sum(mask[1], dim=-1) + epsilon)
+
+    return output
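Reviewer note: a minimal call sketch of `frame_aligned_point_error` with random data; the frames are given a trailing singleton dimension so they broadcast against the points. Illustrative only, not part of this patch.

```python
import torch

from beignet.nn.functional._frame_aligned_point_error import (
    frame_aligned_point_error,
)

frames, points = 4, 16

rotation = torch.eye(3).repeat(frames, 1, 1, 1)  # (4, 1, 3, 3)
translation = torch.zeros(frames, 1, 3)          # (4, 1, 3)

prediction = torch.randn(points, 3)
target = torch.randn(points, 3)

frame_mask = torch.ones(frames)
point_mask = torch.ones(points)

loss = frame_aligned_point_error(
    ((rotation, translation), prediction),
    ((rotation, translation), target),
    (frame_mask, point_mask),
    z=10.0,
)
print(loss)  # scalar tensor
```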
+ """ + transform, input = input + + target_transform, target = target + + epsilon = torch.finfo(input.dtype).eps + + input = beignet.apply_transform( + input[..., None, :, :], + beignet.invert_transform( + transform, + ), + ) + + target = beignet.apply_transform( + target[..., None, :, :], + beignet.invert_transform( + target_transform, + ), + ) + + output = torch.sqrt(torch.sum((input - target) ** 2, dim=-1) + epsilon) + + if wk is not None: + output = torch.clamp(output, 0, wk) + + output = output / z * mask[0][..., None] * mask[1][..., None, :] + + if mp is not None: + output = torch.sum(output * mp, dim=[-1, -2]) / ( + torch.sum(mask[0][..., None] * mask[1][..., None, :] * mp, dim=[-2, -1]) + + epsilon + ) + else: + output = torch.sum( + torch.sum(output, dim=-1) + / (torch.sum(mask[0], dim=-1)[..., None] + epsilon), + dim=-1, + ) / (torch.sum(mask[1], dim=-1) + epsilon) + + return output diff --git a/src/beignet/nn/functional/_global_distance_test.py b/src/beignet/nn/functional/_global_distance_test.py new file mode 100644 index 0000000000..7f59cd4aac --- /dev/null +++ b/src/beignet/nn/functional/_global_distance_test.py @@ -0,0 +1,33 @@ +from typing import Sequence + +import torch +from torch import Tensor + + +def global_distance_test( + input: Tensor, + other: Tensor, + mask: Tensor, + cutoffs: Sequence[float], +) -> Tensor: + n = torch.sum(mask, dim=-1) + + y = input - other + y = y**2 + y = torch.sum(y, dim=-1) + y = torch.sqrt(y) + + scores = torch.zeros(len(cutoffs)) + + for index, cutoff in enumerate(cutoffs): + scores[index] = torch.mean(torch.sum((y <= cutoff) * mask, dim=-1) / n) + + return sum(scores) / len(scores) + + +def global_distance_test_ts(p1, p2, mask): + return global_distance_test(p1, p2, mask, [1.0, 2.0, 4.0, 8.0]) + + +def global_distance_test_ha(p1, p2, mask): + return global_distance_test(p1, p2, mask, [0.5, 1.0, 2.0, 4.0]) diff --git a/src/beignet/nn/functional/_masked_multiple_sequence_alignment_loss.py b/src/beignet/nn/functional/_masked_multiple_sequence_alignment_loss.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/beignet/nn/functional/_per_residue_local_distance_difference_test.py b/src/beignet/nn/functional/_per_residue_local_distance_difference_test.py new file mode 100644 index 0000000000..4b70521293 --- /dev/null +++ b/src/beignet/nn/functional/_per_residue_local_distance_difference_test.py @@ -0,0 +1,24 @@ +# ruff: noqa: E501 + +import torch +import torch.nn.functional +from torch import Tensor + + +def per_residue_local_distance_difference_test(input: Tensor) -> Tensor: + """ + Parameters + ---------- + input : Tensor + + Returns + ------- + output : Tensor + """ + output = torch.nn.functional.softmax(input, dim=-1) + + step = 1.0 / input.shape[-1] + + bounds = torch.arange(0.5 * step, 1.0, step) + + return torch.sum(output * torch.reshape(bounds, [*[1 for _ in range(len(output.shape[:-1]))], *bounds.shape]), dim=-1) * 100.0 # fmt: off diff --git a/src/beignet/nn/functional/_template_modeling_score.py b/src/beignet/nn/functional/_template_modeling_score.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/beignet/nn/functional/_torsion_angle_loss.py b/src/beignet/nn/functional/_torsion_angle_loss.py new file mode 100644 index 0000000000..310d4f6b04 --- /dev/null +++ b/src/beignet/nn/functional/_torsion_angle_loss.py @@ -0,0 +1,47 @@ +from typing import Literal + +import torch +from torch import Tensor + +from beignet.nn.functional import angle_norm_loss + + +def torsion_angle_loss( + input: Tensor, + 
target: (Tensor, Tensor), + reduction: Literal["mean", "sum"] | None = "mean", +) -> Tensor: + """ + Parameters + ---------- + input : Tensor + [*, N, 7, 2] + + target : (Tensor, Tensor) + [*, N, 7, 2], [*, N, 7, 2] + + reduction : str + "mean" or "sum" + """ + x = input / torch.unsqueeze(torch.norm(input, dim=-1), dim=-1) + + y, z = target + + loss = torch.minimum( + torch.norm(x - y, dim=-1) ** 2, + torch.norm(x - z, dim=-1) ** 2, + ) + + match reduction: + case "mean": + loss = torch.mean( + loss, + dim=[-1, -2], + ) + case "sum": + loss = torch.sum( + loss, + dim=[-1, -2], + ) + + return loss + angle_norm_loss(x, reduction) diff --git a/src/beignet/nn/functional/_violation_loss.py b/src/beignet/nn/functional/_violation_loss.py new file mode 100644 index 0000000000..e69de29bb2
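Reviewer note: a minimal call sketch of `torsion_angle_loss` with random data; illustrative only, not part of this patch.

```python
import torch

from beignet.nn.functional import torsion_angle_loss

batch, residues = 2, 8

input = torch.randn(batch, residues, 7, 2)        # unnormalized predictions
target = torch.randn(batch, residues, 7, 2)       # target angles
alternative = torch.randn(batch, residues, 7, 2)  # pi-symmetric alternatives

loss = torch.mean(torsion_angle_loss(input, (target, alternative)))
print(loss)  # scalar tensor
```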