Skip to content

Commit

Permalink
beignet.nn.functional.torsion_angle_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 30, 2024
1 parent 5536dce commit 724e191
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/beignet/nn/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._angle_norm_loss import angle_norm_loss
from ._torsion_angle_loss import torsion_angle_loss
25 changes: 25 additions & 0 deletions src/beignet/nn/functional/_angle_norm_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Literal

import torch
from torch import Tensor


def angle_norm_loss(
input,
reduction: Literal["mean", "sum"] | None = "mean",
) -> Tensor:
loss = torch.abs(torch.norm(input, dim=-1) - 1)

match reduction:
case "mean":
loss = torch.mean(
loss,
dim=[-1, -2],
)
case "sum":
loss = torch.sum(
loss,
dim=[-1, -2],
)

return 2.0 * loss
47 changes: 47 additions & 0 deletions src/beignet/nn/functional/_torsion_angle_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Literal

import torch
from torch import Tensor

from beignet.nn.functional import angle_norm_loss


def torsion_angle_loss(
input: Tensor,
target: (Tensor, Tensor),
reduction: Literal["mean", "sum"] | None = "mean",
) -> Tensor:
"""
Parameters
----------
input : Tensor
[*, N, 7, 2]
target : (Tensor, Tensor)
[*, N, 7, 2], [*, N, 7, 2]
reduction : str
"mean" or "sum"
"""
x = input / torch.unsqueeze(torch.norm(input, dim=-1), dim=-1)

y, z = target

loss = torch.minimum(
torch.norm(x - y, dim=-1) ** 2,
torch.norm(x - z, dim=-1) ** 2,
)

match reduction:
case "mean":
loss = torch.mean(
loss,
dim=[-1, -2],
)
case "sum":
loss = torch.sum(
loss,
dim=[-1, -2],
)

return loss + angle_norm_loss(x, reduction)

0 comments on commit 724e191

Please sign in to comment.