
Commit

bond_length_violation_loss
0x00b1 committed May 1, 2024
1 parent ff46c93 commit fc16aff
Showing 1 changed file with 141 additions and 169 deletions.
310 changes: 141 additions & 169 deletions src/beignet/nn/functional/_bond_length_violation_loss.py
@@ -1,202 +1,174 @@

Removed:

from typing import Sequence

import torch
from torch import Tensor

from beignet.constants import (
    ADJACENT_RESIDUE_PHI_COSINE,
    ADJACENT_RESIDUE_PSI_COSINE,
    AMINO_ACID_3,
)


def bond_length_violation_loss(
    pred_atom_positions: Tensor,  # (*, N, 37/14, 3)
    pred_atom_mask: Tensor,  # (*, N, 37/14)
    residue_index: Tensor,  # (*, N)
    amino_acid: Tensor,  # (*, N)
    tolerance_factor_soft: float = 12.0,
    tolerance_factor_hard: float = 12.0,
) -> dict[str, Tensor]:
    r"""
    Flat-bottom loss to penalize structural violations between residues,
    i.e., any violation of the geometry around the peptide bond between
    consecutive amino acids. This loss corresponds to Jumper et al. (2021),
    Suppl. Sec. 1.9.11, eqs. 44 and 45.

    Parameters
    ----------
    pred_atom_positions : Tensor, shape=(*, N, 37/14, 3)
        Atom positions in atom37/atom14 representation.
    pred_atom_mask : Tensor, shape=(*, N, 37/14)
        Atom mask in atom37/atom14 representation.
    residue_index : Tensor, shape=(*, N)
        Residue index of each amino acid, assumed to be monotonically
        increasing.
    amino_acid : Tensor, shape=(*, N)
        Amino acid type of each residue.
    tolerance_factor_soft : float, optional
        Soft tolerance factor, measured in standard deviations of PDB
        distributions. Default, 12.0.
    tolerance_factor_hard : float, optional
        Hard tolerance factor, measured in standard deviations of PDB
        distributions. Default, 12.0.

    Returns
    -------
    output : dict[str, Tensor]
        Dictionary containing:
        * "c_n_loss_mean": loss for peptide bond length violations
        * "ca_c_n_loss_mean": loss for violations of the bond angle around
          C spanned by CA, C, and N
        * "c_n_ca_loss_mean": loss for violations of the bond angle around
          N spanned by C, N, and CA
        * "per_residue_loss_sum": sum of all losses for each residue
        * "per_residue_violation_mask": mask denoting all residues with a
          violation present
    """
    error_target_0 = ADJACENT_RESIDUE_PSI_COSINE[0]
    error_target_1 = [0.014, 0.016][0]
    error_target_3 = ADJACENT_RESIDUE_PHI_COSINE[0]
    error_target_4 = ADJACENT_RESIDUE_PHI_COSINE[1]

    # The C-N bond to proline has a slightly different length because of
    # the ring.
    next_is_proline = (
        amino_acid[..., 1:]
        == {k: v for v, k in enumerate([*AMINO_ACID_3, "UNK"])}["PRO"]
    )

    gt_length = _gt_length(next_is_proline)

    # Get the positions and masks of the relevant backbone atoms.
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    this_c_pos = pred_atom_positions[..., :-1, 2, :]
    this_c_mask = pred_atom_mask[..., :-1, 2]
    next_n_pos = pred_atom_positions[..., 1:, 0, :]
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]

    has_no_gap_mask = residue_index[..., 1:] - residue_index[..., :-1] == 1.0

    # C-N bond length loss.
    bond_length_0 = _bond_length(next_n_pos, this_c_pos)

    unit_vector_0 = (next_n_pos - this_c_pos) / bond_length_0[..., None]

    error_0 = _error(bond_length_0, gt_length)
    loss_per_residue_0 = _loss_per_residue(
        error_0, _gt_stddev(next_is_proline), tolerance_factor_soft
    )
    loss_0 = _loss(loss_per_residue_0, this_c_mask * next_n_mask * has_no_gap_mask)

    bond_length_1 = _bond_length(this_c_pos, this_ca_pos)
    bond_length_2 = _bond_length(next_ca_pos, next_n_pos)

    # CA-C-N bond angle loss.
    ca_c_n_cos_angle = torch.sum(
        (this_ca_pos - this_c_pos) / bond_length_1[..., None] * unit_vector_0,
        dim=-1,
    )

    error_1 = _error(ca_c_n_cos_angle, error_target_0)
    loss_per_residue_1 = _loss_per_residue(
        error_1, error_target_1, tolerance_factor_soft
    )
    loss_1 = _loss(
        loss_per_residue_1,
        this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask,
    )

    # C-N-CA bond angle loss.
    c_n_ca_cos_angle = _c_n_ca_cos_angle(
        unit_vector_0, (next_ca_pos - next_n_pos) / bond_length_2[..., None]
    )
    error_2 = _error(c_n_ca_cos_angle, error_target_3)

    loss_per_residue_2 = _loss_per_residue(
        error_2, error_target_4, tolerance_factor_soft
    )
    loss_2 = _loss(
        loss_per_residue_2,
        this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask,
    )

    per_residue_loss_sum = _per_residue_loss_sum(
        loss_per_residue_2, loss_per_residue_0, loss_per_residue_1
    )

    violation_mask = _per_residue_violation_mask(
        [error_0, error_2, error_1],
        error_target_4,
        this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask,
        tolerance_factor_hard,
    )

    return {
        "c_n_loss_mean": loss_0,
        "ca_c_n_loss_mean": loss_1,
        "c_n_ca_loss_mean": loss_2,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }


def _c_n_ca_cos_angle(input, other):
    return torch.sum(-input * other, dim=-1)


def _gt_stddev(input):
    a = [0.014, 0.016][0]
    b = [0.014, 0.016][1]

    return ~input * a + input * b


def _gt_length(input):
    a = [1.329, 1.341][0]
    b = [1.329, 1.341][1]

    return ~input * a + input * b


def _per_residue_violation_mask(inputs: Sequence[Tensor], target, mask, temperature):
    output = []

    for input in inputs:
        output = [*output, (input > target * temperature) * mask]

    output = torch.max(torch.stack(output, dim=-2), dim=-2)[0]

    x = torch.nn.functional.pad(output, [0, 1])
    y = torch.nn.functional.pad(output, [1, 0])

    return torch.maximum(x, y)


def _bond_length(input, other):
    output = torch.sum((other - input) ** 2, dim=-1)

    return torch.sqrt(output + torch.finfo(input.dtype).eps)


def _loss(input, mask):
    output = torch.sum(input * mask, dim=-1)

    return output / (torch.sum(mask, dim=-1) + torch.finfo(input.dtype).eps)


def _error(input, target):
    return torch.sqrt((input - target) ** 2 + torch.finfo(input.dtype).eps)


def _loss_per_residue(input, target, temperature):
    return torch.nn.functional.relu(input - target * temperature)


def _per_residue_loss_sum(a, b, c):
    """
    Compute a per-residue loss (the loss is distributed equally to both
    neighbouring residues).
    """
    output = a + b + c

    x = torch.nn.functional.pad(output, [0, 1])
    y = torch.nn.functional.pad(output, [1, 0])

    return 0.5 * (x + y)
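
For reference, a minimal sketch of how the removed function might be called. The shapes follow the docstring; the random inputs, batch size, and printed shapes are illustrative and not from the repository:

import torch

batch, n_residues = 2, 16

losses = bond_length_violation_loss(
    pred_atom_positions=torch.randn(batch, n_residues, 37, 3),
    pred_atom_mask=torch.ones(batch, n_residues, 37),
    residue_index=torch.arange(n_residues).expand(batch, -1),
    amino_acid=torch.zeros(batch, n_residues, dtype=torch.int64),
)

print(losses["c_n_loss_mean"].shape)  # torch.Size([2])
print(losses["per_residue_loss_sum"].shape)  # torch.Size([2, 16])
print(losses["per_residue_violation_mask"].shape)  # torch.Size([2, 16])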
Added:

import torch
from torch import Tensor

from beignet.constants import AMINO_ACID_3


def f4(
    dx: Tensor,  # predicted atom positions, (*, N, 37/14, 3)
    je: Tensor,  # predicted atom mask, (*, N, 37/14)
    s2: Tensor,  # residue index, (*, N)
    r0: Tensor,  # amino acid type, (*, N)
    zc: float = 12.0,  # soft tolerance factor
    gj: float = 12.0,  # hard tolerance factor
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    pn = torch.finfo(dx.dtype).eps

    # Map three-letter amino-acid codes (plus "UNK") to integer indices.
    me = {k: v for v, k in enumerate([*AMINO_ACID_3, "UNK"])}

    # The C-N bond to proline has a slightly different length because of
    # the ring.
    cv = r0[..., 1:] == me["PRO"]

    ip = cv * 1.341
    p3 = cv * 0.016

    n5 = ~cv
    km = ~cv

    n5 = n5 * 1.329
    km = km * 0.014

    # n5: target C-N bond length; km: its standard deviation.
    n5 = n5 + ip
    km = km + p3

    # Backbone atoms: C and CA of residue i; N and CA of residue i + 1.
    ow = dx[..., :-1, 2, :]
    ch = dx[..., +1:, 0, :]
    fj = dx[..., :-1, 1, :]
    n4 = dx[..., +1:, 1, :]

    # Corresponding atom masks.
    k4 = je[..., :-1, 2]
    uo = je[..., +1:, 0]
    o6 = je[..., :-1, 1]
    em = je[..., +1:, 1]

    # Vectors whose norms are the CA-C, N-CA, and C-N bond lengths.
    r3 = fj - ow
    re = ch - n4
    p6 = ow - ch

    r3 = r3**2
    re = re**2
    p6 = p6**2

    r3 = torch.sum(r3, dim=-1)
    re = torch.sum(re, dim=-1)
    p6 = torch.sum(p6, dim=-1)

    r3 = r3 + pn
    re = re + pn
    p6 = p6 + pn

    # Bond lengths.
    r3 = torch.sqrt(r3)
    re = torch.sqrt(re)
    p6 = torch.sqrt(p6)

    # mp: absolute deviation of the C-N bond length from its target.
    mp = n5
    mp = p6 - mp
    mp = mp**2
    mp = mp + pn
    mp = torch.sqrt(mp)

    # Unit vectors along the C-N, N-CA, and C-CA bonds.
    zu = ch - ow
    zu = zu / p6[..., None]

    tn = n4 - ch
    mn = fj - ow

    mn = mn / r3[..., None]
    tn = tn / re[..., None]

    mn = mn * +zu
    tn = tn * -zu

    mn = torch.sum(mn, dim=-1)
    tn = torch.sum(tn, dim=-1)
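    # mn and tn now hold the cosines of the CA-C-N and C-N-CA bond angles.
    # The literals below appear to inline the constants the removed
    # implementation read from ADJACENT_RESIDUE_PSI_COSINE and
    # ADJACENT_RESIDUE_PHI_COSINE: targets of -0.4473 and -0.5203, with
    # standard deviations of 0.0140 and 0.0353.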
    # Absolute deviations of each cosine from its target.
    xa = mn + 0.4473
    xn = tn + 0.5203

    xa = xa**2
    xn = xn**2

    xa = xa + pn
    xn = xn + pn

    xa = torch.sqrt(xa)
    xn = torch.sqrt(xn)

    # Flat-bottom thresholds: the soft tolerance factor times each
    # standard deviation.
    e1 = zc * km
    f6 = zc * 0.0140
    ot = zc * 0.0353

    e1 = mp - e1
    f6 = xa - f6
    ot = xn - ot

    e1 = torch.nn.functional.relu(e1)
    f6 = torch.nn.functional.relu(f6)
    ot = torch.nn.functional.relu(ot)

    # Mask out residue pairs that are not consecutive in the sequence.
    qh = s2[..., 1:] - s2[..., :-1]
    qh = qh == 1.0

    # Pair masks: every atom that enters a term must be present.
    x8 = k4 * uo
    x8 = x8 * qh

    qg = o6 * k4
    qg = qg * uo
    qg = qg * qh

    rc = k4 * uo
    rc = rc * em
    rc = rc * qh

    t4 = e1 * x8
    vc = f6 * qg
    wx = ot * rc

    t4 = torch.sum(t4, dim=-1)
    vc = torch.sum(vc, dim=-1)
    wx = torch.sum(wx, dim=-1)

    ze = torch.sum(x8, dim=-1)
    wf = torch.sum(qg, dim=-1)
    qf = torch.sum(rc, dim=-1)

    ze = ze + pn
    wf = wf + pn
    qf = qf + pn

    # Masked mean losses: C-N bond length, CA-C-N angle, and C-N-CA angle.
    t4 = t4 / ze
    vc = vc / wf
    wx = wx / qf

    # Per-residue loss sum, distributed equally to both neighbouring
    # residues.
    ub = ot + e1
    ub = ub + f6

    rv = torch.nn.functional.pad(ub, [0, 1])
    dk = torch.nn.functional.pad(ub, [1, 0])

    nq = rv + dk
    nq = nq * 0.5

    ic = k4 * uo
    ic = ic * em
    ic = ic * qh

    x3 = []

    # Mark residues whose error exceeds the hard tolerance.
    for zg in [mp, xn, xa]:
        mw = gj * 0.0353
        mw = zg > mw
        mw = mw * ic

        x3 = [*x3, mw]

    mh = torch.stack(x3, dim=-2)

    va, _ = torch.max(mh, dim=-2)

    yn = torch.nn.functional.pad(va, [0, 1])
    gh = torch.nn.functional.pad(va, [1, 0])

    nd = torch.maximum(yn, gh)

    return t4, vc, wx, nq, nd
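
A matching sketch for the added function, with the same illustrative inputs. Judging from the computation, the five returned tensors appear to correspond, in order, to the removed implementation's "c_n_loss_mean", "ca_c_n_loss_mean", "c_n_ca_loss_mean", "per_residue_loss_sum", and "per_residue_violation_mask":

import torch

batch, n_residues = 2, 16

c_n_loss, ca_c_n_loss, c_n_ca_loss, per_residue_loss_sum, violation_mask = f4(
    torch.randn(batch, n_residues, 37, 3),
    torch.ones(batch, n_residues, 37),
    torch.arange(n_residues).expand(batch, -1),
    torch.zeros(batch, n_residues, dtype=torch.int64),
)

print(c_n_loss.shape)  # torch.Size([2])
print(per_residue_loss_sum.shape)  # torch.Size([2, 16])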
