-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
from ._angle_norm_loss import angle_norm_loss | ||
from ._distogram_loss import distogram_loss | ||
from ._per_residue_local_distance_difference_test import ( | ||
per_residue_local_distance_difference_test, | ||
) | ||
from ._torsion_angle_loss import torsion_angle_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# ruff: noqa: E501 | ||
|
||
from typing import Literal | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def distogram_loss(input: Tensor, target: Tensor, mask: Tensor, start: float = 2.3125, end: float = 21.6875, steps: int = 64, reduction: Literal["mean", "sum"] | None = "mean") -> Tensor: # fmt: off | ||
target = torch.nn.functional.one_hot(torch.sum(torch.sum((target[..., None, :] - target[..., None, :, :]) ** 2.0, dim=-1, keepdim=True) > torch.linspace(start, end, steps - 1) ** 2.0, dim=-1), steps) # fmt: off | ||
|
||
mask = mask[..., None] * mask[..., None, :] | ||
|
||
output = torch.sum(torch.sum(torch.sum(torch.nn.functional.log_softmax(input, dim=-1) * target, dim=-1) * -1.0 * mask, dim=-1) / (torch.sum(mask, dim=[-1, -2]) + torch.finfo(mask.dtype).eps)[..., None], dim=-1) # fmt: off | ||
|
||
match reduction: | ||
case "mean": | ||
output = torch.mean(output) | ||
case "sum": | ||
output = torch.sum(output) | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# ruff: noqa: E501 | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
import beignet.operators | ||
|
||
|
||
def frame_aligned_point_error( | ||
input: (Tensor, Tensor, Tensor), | ||
target: (Tensor, Tensor, Tensor), | ||
mask: (Tensor, Tensor), | ||
length_scale: float, | ||
pair_mask: Tensor | None = None, | ||
maximum: float | None = None, | ||
epsilon=1e-8, | ||
) -> Tensor: | ||
""" | ||
Parameters | ||
---------- | ||
input : (Tensor, Tensor, Tensor) | ||
A $3$-tuple of rotation matrices, translations, and positions. The | ||
rotation matrices must have the shape $(\\ldots, 3, 3)$, the | ||
translations must have the shape $(\\ldots, 3)$, and the positions must | ||
have the shape $(\\ldots, \text{points}, 3)$. | ||
target : (Tensor, Tensor, Tensor) | ||
A $3$-tuple of target rotation matrices, translations, and positions. | ||
The rotation matrices must have the shape $(\\ldots, 3, 3)$, the | ||
translations must have the shape $(\\ldots, 3)$, and the positions must | ||
have the shape $(\\ldots, \text{points}, 3)$. | ||
mask : (Tensor, Tensor) | ||
[*, N_frames] binary mask for the frames | ||
[..., points], position masks | ||
length_scale : float | ||
Length scale by which the loss is divided | ||
pair_mask : Tensor | None, optional | ||
[*, N_frames, N_pts] mask to use for separating intra- from inter-chain losses. | ||
maximum : float | None, optional | ||
Cutoff above which distance errors are disregarded | ||
epsilon : float, optional | ||
Small value used to regularize denominators | ||
Returns | ||
------- | ||
output : Tensor | ||
Losses for each frame of shape $(\\ldots, 3)$. | ||
""" | ||
output = torch.sqrt(torch.sum((beignet.operators.apply_transform(input[1][..., None, :, :], beignet.operators.invert_transform(input[0])) - beignet.operators.apply_transform(target[1][..., None, :, :], beignet.operators.invert_transform(target[0]), )) ** 2, dim=-1) + epsilon) # fmt: off | ||
|
||
if maximum is not None: | ||
output = torch.clamp(output, 0, maximum) | ||
|
||
output = output / length_scale * mask[0][..., None] * mask[1][..., None, :] # fmt: off | ||
|
||
if pair_mask is not None: | ||
output = torch.sum(output * pair_mask, dim=[-1, -2]) / (torch.sum(mask[0][..., None] * mask[1][..., None, :] * pair_mask, dim=[-2, -1]) + epsilon) # fmt: off | ||
else: | ||
output = torch.sum((torch.sum(output, dim=-1) / (torch.sum(mask[0], dim=-1))[..., None] + epsilon), dim=-1) / (torch.sum(mask[1], dim=-1) + epsilon) # fmt: off | ||
|
||
return output |
29 changes: 29 additions & 0 deletions
29
src/beignet/nn/functional/_per_residue_local_distance_difference_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import torch | ||
import torch.nn.functional | ||
from torch import Tensor | ||
|
||
|
||
def per_residue_local_distance_difference_test(input: Tensor) -> Tensor: | ||
""" | ||
Parameters | ||
---------- | ||
input : Tensor | ||
Returns | ||
------- | ||
output : Tensor | ||
""" | ||
probs = torch.nn.functional.softmax(input, dim=-1) | ||
|
||
bins = input.shape[-1] | ||
|
||
step = 1.0 / bins | ||
|
||
bounds = torch.arange(0.5 * step, 1.0, step) | ||
|
||
indexes = (1,) * len(probs.shape[:-1]) | ||
output = bounds.view(*indexes, *bounds.shape) | ||
output = probs * output | ||
output = torch.sum(output, dim=-1) | ||
|
||
return output * 100 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from ._apply_rotation_matrix import apply_rotation_matrix | ||
from ._apply_transform import apply_transform | ||
from ._invert_rotation_matrix import invert_rotation_matrix | ||
from ._invert_transform import invert_transform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def apply_rotation_matrix( | ||
input: Tensor, | ||
rotation: Tensor, | ||
inverse: bool | None = False, | ||
) -> Tensor: | ||
r""" | ||
Rotates vectors in three-dimensional space using rotation matrices. | ||
Note | ||
---- | ||
This function interprets the rotation of the original frame to the final | ||
frame as either a projection, where it maps the components of vectors from | ||
the final frame to the original frame, or as a physical rotation, | ||
integrating the vectors into the original frame during the rotation | ||
process. Consequently, the vector components are maintained in the original | ||
frame’s perspective both before and after the rotation. | ||
Parameters | ||
---------- | ||
input : Tensor, shape (..., 3) | ||
Each vector represents a vector in three-dimensional space. The number | ||
of rotation matrices and number of vectors must follow standard | ||
broadcasting rules: either one of them equals unity or they both equal | ||
each other. | ||
rotation : Tensor, shape (..., 3, 3) | ||
Rotation matrices. | ||
inverse : bool, optional | ||
If `True` the inverse of the rotation matrices are applied to the input | ||
vectors. Default, `False`. | ||
Returns | ||
------- | ||
rotated_vectors : Tensor, shape (..., 3) | ||
Rotated vectors. | ||
""" | ||
if inverse: | ||
output = torch.einsum("ikj, ik -> ij", rotation, input) | ||
else: | ||
output = torch.einsum("ijk, ik -> ij", rotation, input) | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from torch import Tensor | ||
|
||
from beignet.operators import apply_rotation_matrix | ||
|
||
|
||
def apply_transform( | ||
input: Tensor, | ||
transform: (Tensor, Tensor), | ||
inverse: bool | None = False, | ||
) -> Tensor: | ||
r""" | ||
Applies three-dimensional transforms to vectors. | ||
Note | ||
---- | ||
This function interprets the rotation of the original frame to the final | ||
frame as either a projection, where it maps the components of vectors from | ||
the final frame to the original frame, or as a physical rotation, | ||
integrating the vectors into the original frame during the rotation | ||
process. Consequently, the vector components are maintained in the original | ||
frame’s perspective both before and after the rotation. | ||
Parameters | ||
---------- | ||
input : Tensor | ||
Each vector represents a vector in three-dimensional space. The number | ||
of rotation matrices, number of translation vectors, and number of | ||
input vectors must follow standard broadcasting rules: either one of | ||
them equals unity or they both equal each other. | ||
transform : (Tensor, Tensor) | ||
Transforms represented as a pair of rotation matrices and translation | ||
vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and | ||
the translations must have the shape $(\ldots, 3)$. | ||
inverse : bool, optional | ||
If `True`, applies the inverse transformation (i.e., inverse rotation | ||
and negated translation) to the input vectors. Default, `False`. | ||
Returns | ||
------- | ||
output : Tensor | ||
Rotated and translated vectors. | ||
""" | ||
rotation, translation = transform | ||
|
||
output = apply_rotation_matrix(input, rotation, inverse=inverse) | ||
|
||
if inverse: | ||
output = output - translation | ||
else: | ||
output = output + translation | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def invert_rotation_matrix(input: Tensor) -> Tensor: | ||
r""" | ||
Invert rotation matrices. | ||
Parameters | ||
---------- | ||
input : Tensor, shape=(..., 3, 3) | ||
Rotation matrices. | ||
Returns | ||
------- | ||
output : Tensor, shape=(..., 3, 3) | ||
Inverted rotation matrices. | ||
""" | ||
return torch.transpose(input, -2, -1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
import beignet.operators | ||
|
||
|
||
def invert_transform(input: (Tensor, Tensor)) -> (Tensor, Tensor): | ||
r""" | ||
Invert transforms. | ||
Parameters | ||
---------- | ||
input : (Tensor, Tensor) | ||
Transforms represented as a pair of rotation matrices and translation | ||
vectors. The rotation matrices must have the shape $(\ldots, 3, 3)$ and | ||
the translations must have the shape $(\ldots, 3)$. | ||
Returns | ||
------- | ||
output : (Tensor, Tensor) | ||
Inverted transforms represented as a pair of rotation matrices and | ||
translation vectors. The rotation matrices have the shape | ||
$(\ldots, 3, 3)$ and the translations have the shape $(\ldots, 3)$. | ||
""" | ||
rotation, translation = input | ||
|
||
rotation = beignet.operators.invert_rotation_matrix(rotation) | ||
|
||
return rotation, -rotation @ torch.squeeze( | ||
torch.unsqueeze(translation, dim=-1), dim=-1 | ||
) |