Skip to content

Commit

Permalink
AlphaFold losses
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 30, 2024
1 parent 724e191 commit 61e8ef3
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/beignet/nn/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from ._angle_norm_loss import angle_norm_loss
from ._distogram_loss import distogram_loss
from ._per_residue_local_distance_difference_test import (
per_residue_local_distance_difference_test,
)
from ._torsion_angle_loss import torsion_angle_loss
22 changes: 22 additions & 0 deletions src/beignet/nn/functional/_distogram_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# ruff: noqa: E501

from typing import Literal

import torch
from torch import Tensor


def distogram_loss(input: Tensor, target: Tensor, mask: Tensor, start: float = 2.3125, end: float = 21.6875, steps: int = 64, reduction: Literal["mean", "sum"] | None = "mean") -> Tensor: # fmt: off
target = torch.nn.functional.one_hot(torch.sum(torch.sum((target[..., None, :] - target[..., None, :, :]) ** 2.0, dim=-1, keepdim=True) > torch.linspace(start, end, steps - 1) ** 2.0, dim=-1), steps) # fmt: off

mask = mask[..., None] * mask[..., None, :]

output = torch.sum(torch.sum(torch.sum(torch.nn.functional.log_softmax(input, dim=-1) * target, dim=-1) * -1.0 * mask, dim=-1) / (torch.sum(mask, dim=[-1, -2]) + torch.finfo(mask.dtype).eps)[..., None], dim=-1) # fmt: off

match reduction:
case "mean":
output = torch.mean(output)
case "sum":
output = torch.sum(output)

return output
66 changes: 66 additions & 0 deletions src/beignet/nn/functional/_frame_aligned_point_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# ruff: noqa: E501

import torch
from torch import Tensor

import beignet.operators


def frame_aligned_point_error(
input: (Tensor, Tensor, Tensor),
target: (Tensor, Tensor, Tensor),
mask: (Tensor, Tensor),
length_scale: float,
pair_mask: Tensor | None = None,
maximum: float | None = None,
epsilon=1e-8,
) -> Tensor:
"""
Parameters
----------
input : (Tensor, Tensor, Tensor)
A $3$-tuple of rotation matrices, translations, and positions. The
rotation matrices must have the shape $(\\ldots, 3, 3)$, the
translations must have the shape $(\\ldots, 3)$, and the positions must
have the shape $(\\ldots, \text{points}, 3)$.
target : (Tensor, Tensor, Tensor)
A $3$-tuple of target rotation matrices, translations, and positions.
The rotation matrices must have the shape $(\\ldots, 3, 3)$, the
translations must have the shape $(\\ldots, 3)$, and the positions must
have the shape $(\\ldots, \text{points}, 3)$.
mask : (Tensor, Tensor)
[*, N_frames] binary mask for the frames
[..., points], position masks
length_scale : float
Length scale by which the loss is divided
pair_mask : Tensor | None, optional
[*, N_frames, N_pts] mask to use for separating intra- from inter-chain losses.
maximum : float | None, optional
Cutoff above which distance errors are disregarded
epsilon : float, optional
Small value used to regularize denominators
Returns
-------
output : Tensor
Losses for each frame of shape $(\\ldots, 3)$.
"""
output = torch.sqrt(torch.sum((beignet.operators.apply_transform(input[1][..., None, :, :], beignet.operators.invert_transform(input[0])) - beignet.operators.apply_transform(target[1][..., None, :, :], beignet.operators.invert_transform(target[0]), )) ** 2, dim=-1) + epsilon) # fmt: off

if maximum is not None:
output = torch.clamp(output, 0, maximum)

output = output / length_scale * mask[0][..., None] * mask[1][..., None, :] # fmt: off

if pair_mask is not None:
output = torch.sum(output * pair_mask, dim=[-1, -2]) / (torch.sum(mask[0][..., None] * mask[1][..., None, :] * pair_mask, dim=[-2, -1]) + epsilon) # fmt: off
else:
output = torch.sum((torch.sum(output, dim=-1) / (torch.sum(mask[0], dim=-1))[..., None] + epsilon), dim=-1) / (torch.sum(mask[1], dim=-1) + epsilon) # fmt: off

return output
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
import torch.nn.functional
from torch import Tensor


def per_residue_local_distance_difference_test(input: Tensor) -> Tensor:
"""
Parameters
----------
input : Tensor
Returns
-------
output : Tensor
"""
probs = torch.nn.functional.softmax(input, dim=-1)

bins = input.shape[-1]

step = 1.0 / bins

bounds = torch.arange(0.5 * step, 1.0, step)

indexes = (1,) * len(probs.shape[:-1])
output = bounds.view(*indexes, *bounds.shape)
output = probs * output
output = torch.sum(output, dim=-1)

return output * 100
4 changes: 4 additions & 0 deletions src/beignet/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ._apply_rotation_matrix import apply_rotation_matrix
from ._apply_transform import apply_transform
from ._invert_rotation_matrix import invert_rotation_matrix
from ._invert_transform import invert_transform
47 changes: 47 additions & 0 deletions src/beignet/operators/_apply_rotation_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
from torch import Tensor


def apply_rotation_matrix(
input: Tensor,
rotation: Tensor,
inverse: bool | None = False,
) -> Tensor:
r"""
Rotates vectors in three-dimensional space using rotation matrices.
Note
----
This function interprets the rotation of the original frame to the final
frame as either a projection, where it maps the components of vectors from
the final frame to the original frame, or as a physical rotation,
integrating the vectors into the original frame during the rotation
process. Consequently, the vector components are maintained in the original
frame’s perspective both before and after the rotation.
Parameters
----------
input : Tensor, shape (..., 3)
Each vector represents a vector in three-dimensional space. The number
of rotation matrices and number of vectors must follow standard
broadcasting rules: either one of them equals unity or they both equal
each other.
rotation : Tensor, shape (..., 3, 3)
Rotation matrices.
inverse : bool, optional
If `True` the inverse of the rotation matrices are applied to the input
vectors. Default, `False`.
Returns
-------
rotated_vectors : Tensor, shape (..., 3)
Rotated vectors.
"""
if inverse:
output = torch.einsum("ikj, ik -> ij", rotation, input)
else:
output = torch.einsum("ijk, ik -> ij", rotation, input)

return output
54 changes: 54 additions & 0 deletions src/beignet/operators/_apply_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from torch import Tensor

from beignet.operators import apply_rotation_matrix


def apply_transform(
input: Tensor,
transform: (Tensor, Tensor),
inverse: bool | None = False,
) -> Tensor:
r"""
Applies three-dimensional transforms to vectors.
Note
----
This function interprets the rotation of the original frame to the final
frame as either a projection, where it maps the components of vectors from
the final frame to the original frame, or as a physical rotation,
integrating the vectors into the original frame during the rotation
process. Consequently, the vector components are maintained in the original
frame’s perspective both before and after the rotation.
Parameters
----------
input : Tensor
Each vector represents a vector in three-dimensional space. The number
of rotation matrices, number of translation vectors, and number of
input vectors must follow standard broadcasting rules: either one of
them equals unity or they both equal each other.
transform : (Tensor, Tensor)
Transforms represented as a pair of rotation matrices and translation
vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and
the translations must have the shape $(\ldots, 3)$.
inverse : bool, optional
If `True`, applies the inverse transformation (i.e., inverse rotation
and negated translation) to the input vectors. Default, `False`.
Returns
-------
output : Tensor
Rotated and translated vectors.
"""
rotation, translation = transform

output = apply_rotation_matrix(input, rotation, inverse=inverse)

if inverse:
output = output - translation
else:
output = output + translation

return output
19 changes: 19 additions & 0 deletions src/beignet/operators/_invert_rotation_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
from torch import Tensor


def invert_rotation_matrix(input: Tensor) -> Tensor:
r"""
Invert rotation matrices.
Parameters
----------
input : Tensor, shape=(..., 3, 3)
Rotation matrices.
Returns
-------
output : Tensor, shape=(..., 3, 3)
Inverted rotation matrices.
"""
return torch.transpose(input, -2, -1)
31 changes: 31 additions & 0 deletions src/beignet/operators/_invert_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from torch import Tensor

import beignet.operators


def invert_transform(input: (Tensor, Tensor)) -> (Tensor, Tensor):
r"""
Invert transforms.
Parameters
----------
input : (Tensor, Tensor)
Transforms represented as a pair of rotation matrices and translation
vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and
the translations must have the shape $(\ldots, 3)$.
Returns
-------
output : (Tensor, Tensor)
Inverted transforms represented as a pair of rotation matrices and
translation vectors. The rotation matrices have the shape
$(\ldots, 3, 3)$ and the translations have the shape $(\ldots, 3)$.
"""
rotation, translation = input

rotation = beignet.operators.invert_rotation_matrix(rotation)

return rotation, -rotation @ torch.squeeze(
torch.unsqueeze(translation, dim=-1), dim=-1
)

0 comments on commit 61e8ef3

Please sign in to comment.