Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AlphaFold losses #15

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
repos:
- hooks:
- id: "check-toml"
- id: "check-yaml"
repo: "https://github.com/pre-commit/pre-commit-hooks"
rev: "v4.6.0"
- hooks:
- args:
- "--fix"
id: "ruff"
- id: "ruff-format"
repo: "https://github.com/astral-sh/ruff-pre-commit"
rev: "v0.3.7"
26 changes: 24 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,30 @@ name = "beignet"
readme = "README.md"
requires-python = ">=3.10"

[tool.ruff.format]
docstring-code-format = true
[tool.ruff]
exclude = [
"./src/beignet/constants/_adjacent_residue_phi_cosine.py",
"./src/beignet/constants/_adjacent_residue_psi_cosine.py",
"./src/beignet/constants/_amino_acid_1.py",
"./src/beignet/constants/_amino_acid_1_to_amino_acid_3.py",
"./src/beignet/constants/_amino_acid_3.py",
"./src/beignet/constants/_amino_acid_3_to_amino_acid_1.py",
"./src/beignet/constants/_amino_acid_3_to_atom_14.py",
]

[tool.ruff.lint]
select = [
"B", # FLAKE8-BUGBEAR
"E", # PYCODESTYLE ERRORS
"F", # PYFLAKES
"I", # ISORT
"W", # PYCODESTYLE WARNINGS
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = [
"F401", # MODULE IMPORTED BUT UNUSED
]

[tool.setuptools_scm]
local_scheme = "no-local-version"
12 changes: 12 additions & 0 deletions src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,15 @@
__version__ = importlib.metadata.version("beignet")
except PackageNotFoundError:
__version__ = None

from ._apply_rotation_matrix import apply_rotation_matrix
from ._apply_transform import apply_transform
from ._invert_rotation_matrix import invert_rotation_matrix
from ._invert_transform import invert_transform

__all__ = [
"apply_rotation_matrix",
"apply_transform",
"invert_rotation_matrix",
"invert_transform",
]
47 changes: 47 additions & 0 deletions src/beignet/_apply_rotation_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
from torch import Tensor


def apply_rotation_matrix(
input: Tensor,
rotation: Tensor,
inverse: bool | None = False,
) -> Tensor:
r"""
Rotates vectors in three-dimensional space using rotation matrices.

Note
----
This function interprets the rotation of the original frame to the final
frame as either a projection, where it maps the components of vectors from
the final frame to the original frame, or as a physical rotation,
integrating the vectors into the original frame during the rotation
process. Consequently, the vector components are maintained in the original
frame’s perspective both before and after the rotation.

Parameters
----------
input : Tensor, shape (..., 3)
Each vector represents a vector in three-dimensional space. The number
of rotation matrices and number of vectors must follow standard
broadcasting rules: either one of them equals unity or they both equal
each other.

rotation : Tensor, shape (..., 3, 3)
Rotation matrices.

inverse : bool, optional
If `True` the inverse of the rotation matrices are applied to the input
vectors. Default, `False`.

Returns
-------
rotated_vectors : Tensor, shape (..., 3)
Rotated vectors.
"""
if inverse:
output = torch.einsum("ikj, ik -> ij", rotation, input)
else:
output = torch.einsum("ijk, ik -> ij", rotation, input)

return output
54 changes: 54 additions & 0 deletions src/beignet/_apply_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from torch import Tensor

from beignet.operators import apply_rotation_matrix


def apply_transform(
input: Tensor,
transform: (Tensor, Tensor),
inverse: bool | None = False,
) -> Tensor:
r"""
Applies three-dimensional transforms to vectors.

Note
----
This function interprets the rotation of the original frame to the final
frame as either a projection, where it maps the components of vectors from
the final frame to the original frame, or as a physical rotation,
integrating the vectors into the original frame during the rotation
process. Consequently, the vector components are maintained in the original
frame’s perspective both before and after the rotation.

Parameters
----------
input : Tensor
Each vector represents a vector in three-dimensional space. The number
of rotation matrices, number of translation vectors, and number of
input vectors must follow standard broadcasting rules: either one of
them equals unity or they both equal each other.

transform : (Tensor, Tensor)
Transforms represented as a pair of rotation matrices and translation
vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and
the translations must have the shape $(\ldots, 3)$.

inverse : bool, optional
If `True`, applies the inverse transformation (i.e., inverse rotation
and negated translation) to the input vectors. Default, `False`.

Returns
-------
output : Tensor
Rotated and translated vectors.
"""
rotation, translation = transform

output = apply_rotation_matrix(input, rotation, inverse=inverse)

if inverse:
output = output - translation
else:
output = output + translation

return output
19 changes: 19 additions & 0 deletions src/beignet/_invert_rotation_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
from torch import Tensor


def invert_rotation_matrix(input: Tensor) -> Tensor:
r"""
Invert rotation matrices.

Parameters
----------
input : Tensor, shape=(..., 3, 3)
Rotation matrices.

Returns
-------
output : Tensor, shape=(..., 3, 3)
Inverted rotation matrices.
"""
return torch.transpose(input, -2, -1)
31 changes: 31 additions & 0 deletions src/beignet/_invert_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from torch import Tensor

import beignet.operators


def invert_transform(input: (Tensor, Tensor)) -> (Tensor, Tensor):
r"""
Invert transforms.

Parameters
----------
input : (Tensor, Tensor)
Transforms represented as a pair of rotation matrices and translation
vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and
the translations must have the shape $(\ldots, 3)$.

Returns
-------
output : (Tensor, Tensor)
Inverted transforms represented as a pair of rotation matrices and
translation vectors. The rotation matrices have the shape
$(\ldots, 3, 3)$ and the translations have the shape $(\ldots, 3)$.
"""
rotation, translation = input

rotation = beignet.operators.invert_rotation_matrix(rotation)

return rotation, -rotation @ torch.squeeze(
torch.unsqueeze(translation, dim=-1), dim=-1
)
17 changes: 17 additions & 0 deletions src/beignet/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from ._adjacent_residue_phi_cosine import ADJACENT_RESIDUE_PHI_COSINE
from ._adjacent_residue_psi_cosine import ADJACENT_RESIDUE_PSI_COSINE
from ._amino_acid_1 import AMINO_ACID_1
from ._amino_acid_1_to_amino_acid_3 import AMINO_ACID_1_TO_AMINO_ACID_3
from ._amino_acid_3 import AMINO_ACID_3
from ._amino_acid_3_to_amino_acid_1 import AMINO_ACID_3_TO_AMINO_ACID_1
from ._amino_acid_3_to_atom_14 import AMINO_ACID_3_TO_ATOM_14

__all__ = [
"ADJACENT_RESIDUE_PHI_COSINE",
"ADJACENT_RESIDUE_PSI_COSINE",
"AMINO_ACID_1",
"AMINO_ACID_1_TO_AMINO_ACID_3",
"AMINO_ACID_3",
"AMINO_ACID_3_TO_AMINO_ACID_1",
"AMINO_ACID_3_TO_ATOM_14",
]
1 change: 1 addition & 0 deletions src/beignet/constants/_adjacent_residue_phi_cosine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ADJACENT_RESIDUE_PHI_COSINE = [-0.5203, 0.0353]
1 change: 1 addition & 0 deletions src/beignet/constants/_adjacent_residue_psi_cosine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ADJACENT_RESIDUE_PSI_COSINE = [-0.4473, 0.0311]
22 changes: 22 additions & 0 deletions src/beignet/constants/_amino_acid_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
AMINO_ACID_1 = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
22 changes: 22 additions & 0 deletions src/beignet/constants/_amino_acid_1_to_amino_acid_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
AMINO_ACID_1_TO_AMINO_ACID_3 = {
"A": "ALA",
"R": "ARG",
"N": "ASN",
"D": "ASP",
"C": "CYS",
"Q": "GLN",
"E": "GLU",
"G": "GLY",
"H": "HIS",
"I": "ILE",
"L": "LEU",
"K": "LYS",
"M": "MET",
"F": "PHE",
"P": "PRO",
"S": "SER",
"T": "THR",
"W": "TRP",
"Y": "TYR",
"V": "VAL",
}
22 changes: 22 additions & 0 deletions src/beignet/constants/_amino_acid_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
AMINO_ACID_3 = [
"ALA",
"ARG",
"ASN",
"ASP",
"CYS",
"GLN",
"GLU",
"GLY",
"HIS",
"ILE",
"LEU",
"LYS",
"MET",
"PHE",
"PRO",
"SER",
"THR",
"TRP",
"TYR",
"VAL",
]
22 changes: 22 additions & 0 deletions src/beignet/constants/_amino_acid_3_to_amino_acid_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
AMINO_ACID_3_TO_AMINO_ACID_1 = {
"ALA": "A",
"ARG": "R",
"ASN": "N",
"ASP": "D",
"CYS": "C",
"GLN": "Q",
"GLU": "E",
"GLY": "G",
"HIS": "H",
"ILE": "I",
"LEU": "L",
"LYS": "K",
"MET": "M",
"PHE": "F",
"PRO": "P",
"SER": "S",
"THR": "T",
"TRP": "W",
"TYR": "Y",
"VAL": "V",
}
23 changes: 23 additions & 0 deletions src/beignet/constants/_amino_acid_3_to_atom_14.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
AMINO_ACID_3_TO_ATOM_14 = {
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", "" ],
"ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", "" ],
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", "" ],
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", "" ],
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", "" ],
"GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", "" ],
"GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", "" ],
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", "" ],
"HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", "" ],
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", "" ],
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", "" ],
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", "" ],
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", "" ],
"PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", "" ],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", "" ],
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", "" ],
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", "" ],
"TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"],
"TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", "" ],
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", "" ],
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", "" ],
}
Empty file added src/beignet/nn/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions src/beignet/nn/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._angle_norm_loss import angle_norm_loss
from ._distogram_loss import distogram_loss
from ._global_distance_test import global_distance_test
from ._per_residue_local_distance_difference_test import (
per_residue_local_distance_difference_test,
)
from ._torsion_angle_loss import torsion_angle_loss
Empty file.
25 changes: 25 additions & 0 deletions src/beignet/nn/functional/_angle_norm_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Literal

import torch
from torch import Tensor


def angle_norm_loss(
input,
reduction: Literal["mean", "sum"] | None = "mean",
) -> Tensor:
loss = torch.abs(torch.norm(input, dim=-1) - 1)

match reduction:
case "mean":
loss = torch.mean(
loss,
dim=[-1, -2],
)
case "sum":
loss = torch.sum(
loss,
dim=[-1, -2],
)

return 2.0 * loss
Empty file.
Loading
Loading