Commit
Showing 1 changed file with 141 additions and 169 deletions.

src/beignet/nn/functional/_bond_length_violation_loss.py
@@ -1,202 +1,174 @@

Previous version of the file:

from typing import Sequence

import torch
from torch import Tensor

from beignet.constants import (
    ADJACENT_RESIDUE_PHI_COSINE,
    ADJACENT_RESIDUE_PSI_COSINE,
    AMINO_ACID_3,
)


def bond_length_violation_loss(
    pred_atom_positions: Tensor,  # (*, N, 37/14, 3)
    pred_atom_mask: Tensor,  # (*, N, 37/14)
    residue_index: Tensor,  # (*, N)
    amino_acid: Tensor,  # (*, N)
    tolerance_factor_soft: float = 12.0,
    tolerance_factor_hard: float = 12.0,
) -> dict[str, Tensor]:
    r"""
    Flat-bottom loss to penalize structural violations between residues.
    This is a loss penalizing any violation of the geometry around the
    peptide bond between consecutive amino acids. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq. 44, 45.

    Parameters
    ----------
    pred_atom_positions : Tensor, shape=(*, N, 37/14, 3)
        Atom positions in atom37/atom14 representation.
    pred_atom_mask : Tensor, shape=(*, N, 37/14)
        Atom mask in atom37/atom14 representation.
    residue_index : Tensor, shape=(*, N)
        Residue index of each amino acid; assumed to be monotonically
        increasing.
    amino_acid : Tensor, shape=(*, N)
        Amino acid type of each residue.
    tolerance_factor_soft : float, optional
        Soft tolerance factor measured in standard deviations of PDB
        distributions. Default, 12.0.
    tolerance_factor_hard : float, optional
        Hard tolerance factor measured in standard deviations of PDB
        distributions. Default, 12.0.

    Returns
    -------
    output : dict[str, Tensor]
        * 'c_n_loss_mean': loss for peptide bond length violations
        * 'ca_c_n_loss_mean': loss for violations of the bond angle around C
          spanned by CA, C, N
        * 'c_n_ca_loss_mean': loss for violations of the bond angle around N
          spanned by C, N, CA
        * 'per_residue_loss_sum': sum of all losses for each residue
        * 'per_residue_violation_mask': mask denoting all residues with a
          violation present
    """
    error_target_0 = ADJACENT_RESIDUE_PSI_COSINE[0]
    error_target_1 = [0.014, 0.016][0]
    error_target_3 = ADJACENT_RESIDUE_PHI_COSINE[0]
    error_target_4 = ADJACENT_RESIDUE_PHI_COSINE[1]

    # The C-N bond to proline has a slightly different length because of
    # the ring.
    next_is_proline = (
        amino_acid[..., 1:]
        == {k: v for v, k in enumerate([*AMINO_ACID_3, "UNK"])}["PRO"]
    )

    gt_length = _gt_length(next_is_proline)

    # Get the positions of the relevant backbone atoms.
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    this_c_pos = pred_atom_positions[..., :-1, 2, :]
    this_c_mask = pred_atom_mask[..., :-1, 2]
    next_n_pos = pred_atom_positions[..., 1:, 0, :]
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]

    has_no_gap_mask = residue_index[..., 1:] - residue_index[..., :-1] == 1.0

    bond_length_0 = _bond_length(next_n_pos, this_c_pos)

    unit_vector_0 = (next_n_pos - this_c_pos) / bond_length_0[..., None]

    error_0 = _error(bond_length_0, gt_length)
    loss_per_residue_0 = _loss_per_residue(
        error_0, _gt_stddev(next_is_proline), tolerance_factor_soft
    )
    loss_0 = _loss(loss_per_residue_0, this_c_mask * next_n_mask * has_no_gap_mask)

    bond_length_1 = _bond_length(this_c_pos, this_ca_pos)
    bond_length_2 = _bond_length(next_ca_pos, next_n_pos)

    ca_c_n_cos_angle = torch.sum(
        (this_ca_pos - this_c_pos) / bond_length_1[..., None] * unit_vector_0, dim=-1
    )

    error_1 = _error(ca_c_n_cos_angle, error_target_0)
    loss_per_residue_1 = _loss_per_residue(
        error_1, error_target_1, tolerance_factor_soft
    )
    loss_1 = _loss(
        loss_per_residue_1, this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    )

    c_n_ca_cos_angle = _c_n_ca_cos_angle(
        unit_vector_0, (next_ca_pos - next_n_pos) / bond_length_2[..., None]
    )
    error_2 = _error(c_n_ca_cos_angle, error_target_3)

    loss_per_residue_2 = _loss_per_residue(
        error_2, error_target_4, tolerance_factor_soft
    )
    loss_2 = _loss(
        loss_per_residue_2, this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    )

    per_residue_loss_sum = _per_residue_loss_sum(
        loss_per_residue_2, loss_per_residue_0, loss_per_residue_1
    )

    violation_mask = _per_residue_violation_mask(
        [error_0, error_2, error_1],
        error_target_4,
        this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask,
        tolerance_factor_hard,
    )

    return {
        "c_n_loss_mean": loss_0,
        "ca_c_n_loss_mean": loss_1,
        "c_n_ca_loss_mean": loss_2,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }


def _c_n_ca_cos_angle(input, other):
    return torch.sum(-input * other, dim=-1)


def _gt_stddev(input):
    a = [0.014, 0.016][0]
    b = [0.014, 0.016][1]

    return ~input * a + input * b


def _gt_length(input):
    a = [1.329, 1.341][0]
    b = [1.329, 1.341][1]

    return ~input * a + input * b


def _per_residue_violation_mask(inputs: Sequence[Tensor], target, mask, temperature):
    output = []

    for input in inputs:
        output = [*output, ((input > target * temperature) * mask)]

    output = torch.max(torch.stack(output, dim=-2), dim=-2)[0]

    x = torch.nn.functional.pad(output, [0, 1])
    y = torch.nn.functional.pad(output, [1, 0])

    return torch.maximum(x, y)


def _bond_length(input, other):
    output = torch.sum((other - input) ** 2, dim=-1)

    return torch.sqrt(output + torch.finfo(input.dtype).eps)


def _loss(input, mask):
    output = torch.sum(input * mask, dim=-1)

    return output / (torch.sum(mask, dim=-1) + torch.finfo(input.dtype).eps)


def _error(input, target):
    return torch.sqrt((input - target) ** 2 + torch.finfo(input.dtype).eps)


def _loss_per_residue(input, target, temperature):
    return torch.nn.functional.relu(input - target * temperature)


def _per_residue_loss_sum(a, b, c):
    """
    Compute a per-residue loss (equally distributing the loss to both
    neighbouring residues).
    """
    output = a + b + c

    x = torch.nn.functional.pad(output, [0, 1])
    y = torch.nn.functional.pad(output, [1, 0])

    return 0.5 * (x + y)
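Each of the three penalties in this version has the same flat-bottom shape: an epsilon-stabilized absolute error is compared against a tolerance expressed in standard deviations of the corresponding PDB distribution, and only the excess is penalized. In the notation of `_error` and `_loss_per_residue`, with ground-truth mean $\mu$, standard deviation $\sigma$, and tolerance factor $k$ (12 by default), the per-residue term reads

$$\ell_i = \mathrm{ReLU}\!\left(\sqrt{(x_i - \mu)^2 + \varepsilon} \; - \; k\,\sigma\right),$$

so any prediction within $k\sigma$ of the PDB mean contributes zero loss. For the C-N peptide bond, for example, the code uses $\mu = 1.329$ (1.341 before proline) and $\sigma = 0.014$ (0.016 before proline).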
New version of the file:

import torch
from torch import Tensor

from beignet.constants import AMINO_ACID_3


# IGNORE THIS:
def f4(
    dx: Tensor,
    je: Tensor,
    s2: Tensor,
    r0: Tensor,
    zc: float = 12.0,
    gj: float = 12.0,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    pn = torch.finfo(dx.dtype).eps

    # Map three-letter amino-acid codes to indices.
    me = {k: v for v, k in enumerate([*AMINO_ACID_3, "UNK"])}

    # Mask of positions whose next residue is proline.
    cv = r0[..., 1:] == me["PRO"]

    # Ground-truth C-N bond length and standard deviation, with the
    # proline-specific values where the next residue is proline.
    ip = cv * 1.341
    p3 = cv * 0.016

    n5 = ~cv
    km = ~cv

    n5 = n5 * 1.329
    km = km * 0.014

    n5 = n5 + ip
    km = km + p3

    # Backbone atom positions: this C, next N, this CA, next CA.
    ow = dx[..., :-1, 2, :]
    ch = dx[..., +1:, 0, :]
    fj = dx[..., :-1, 1, :]
    n4 = dx[..., +1:, 1, :]

    # Corresponding atom masks.
    k4 = je[..., :-1, 2]
    uo = je[..., +1:, 0]
    o6 = je[..., :-1, 1]
    em = je[..., +1:, 1]

    # Epsilon-stabilized distances: CA-C, N-CA, and the C-N peptide bond.
    r3 = fj - ow
    re = ch - n4
    p6 = ow - ch

    r3 = r3**2
    re = re**2
    p6 = p6**2

    r3 = torch.sum(r3, dim=-1)
    re = torch.sum(re, dim=-1)
    p6 = torch.sum(p6, dim=-1)

    r3 = r3 + pn
    re = re + pn
    p6 = p6 + pn

    r3 = torch.sqrt(r3)
    re = torch.sqrt(re)
    p6 = torch.sqrt(p6)

    # Absolute error of the C-N bond length against its ground truth.
    mp = n5
    mp = p6 - mp
    mp = mp**2
    mp = mp + pn
    mp = torch.sqrt(mp)

    # Unit vector from this C to the next N.
    zu = ch - ow
    zu = zu / p6[..., None]

    # Unit vectors C -> CA and N -> CA, then the two bond-angle cosines.
    tn = n4 - ch
    mn = fj - ow

    mn = mn / r3[..., None]
    tn = tn / re[..., None]

    mn = mn * +zu
    tn = tn * -zu

    mn = torch.sum(mn, dim=-1)
    tn = torch.sum(tn, dim=-1)

    # Absolute errors of the angle cosines against their ground truths.
    xa = mn + 0.4473
    xn = tn + 0.5203

    xa = xa**2
    xn = xn**2

    xa = xa + pn
    xn = xn + pn

    xa = torch.sqrt(xa)
    xn = torch.sqrt(xn)

    # Flat-bottom penalties: only the error in excess of the soft
    # tolerance (zc standard deviations) is penalized.
    e1 = zc * km
    f6 = zc * 0.0140
    ot = zc * 0.0353

    e1 = mp - e1
    f6 = xa - f6
    ot = xn - ot

    e1 = torch.nn.functional.relu(e1)
    f6 = torch.nn.functional.relu(f6)
    ot = torch.nn.functional.relu(ot)

    # Mask of consecutive residue indices (no chain break).
    qh = s2[..., 1:] - s2[..., :-1]
    qh = qh == 1.0

    # Pair masks for each of the three penalty terms.
    x8 = k4 * uo
    x8 = x8 * qh

    qg = o6 * k4
    qg = qg * uo
    qg = qg * qh

    rc = k4 * uo
    rc = rc * em
    rc = rc * qh

    t4 = e1 * x8
    vc = f6 * qg
    wx = ot * rc

    t4 = torch.sum(t4, dim=-1)
    vc = torch.sum(vc, dim=-1)
    wx = torch.sum(wx, dim=-1)

    ze = torch.sum(x8, dim=-1)
    wf = torch.sum(qg, dim=-1)
    qf = torch.sum(rc, dim=-1)

    ze = ze + pn
    wf = wf + pn
    qf = qf + pn

    # Masked means of the three penalties.
    t4 = t4 / ze
    vc = vc / wf
    wx = wx / qf

    # Sum of the three penalties, split evenly between the two residues
    # that share each peptide bond.
    ub = ot + e1
    ub = ub + f6

    rv = torch.nn.functional.pad(ub, [0, 1])
    dk = torch.nn.functional.pad(ub, [1, 0])

    nq = rv + dk
    nq = nq * 0.5

    ic = k4 * uo
    ic = ic * em
    ic = ic * qh

    x3 = []

    # Hard violations: any error beyond gj standard deviations.
    for zg in [mp, xn, xa]:
        mw = gj * 0.0353
        mw = zg > mw
        mw = mw * ic

        x3 = [*x3, mw]

    mh = torch.stack(x3, dim=-2)

    va, _ = torch.max(mh, dim=-2)

    # Attribute each violation to both neighbouring residues.
    yn = torch.nn.functional.pad(va, [0, 1])
    gh = torch.nn.functional.pad(va, [1, 0])

    nd = torch.maximum(yn, gh)

    return t4, vc, wx, nq, nd
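For orientation, a minimal smoke test of the new entry point is sketched below; the atom37 shapes, the dummy values, and the import path (inferred from the file location) are illustrative assumptions, not part of the commit:

import torch

# Assumed import path, matching the file's location in the repository.
from beignet.nn.functional._bond_length_violation_loss import f4

# Batch of 2 chains with 16 residues in atom37 layout.
positions = torch.randn(2, 16, 37, 3)
atom_mask = torch.ones(2, 16, 37)
residue_index = torch.arange(16, dtype=torch.float32).expand(2, 16)
amino_acid = torch.zeros(2, 16, dtype=torch.int64)

c_n, ca_c_n, c_n_ca, per_residue_sum, violation_mask = f4(
    positions, atom_mask, residue_index, amino_acid
)

print(c_n.shape, per_residue_sum.shape, violation_mask.shape)
# torch.Size([2]) torch.Size([2, 16]) torch.Size([2, 16])

Judging by the masks and targets involved, the five returned tensors appear to correspond, in order, to the 'c_n_loss_mean', 'ca_c_n_loss_mean', 'c_n_ca_loss_mean', 'per_residue_loss_sum', and 'per_residue_violation_mask' entries of the dict returned by the previous version.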