# Post-patch contents of
# src/beignet/nn/functional/_bond_length_violation_loss.py as introduced by
# commit fc16affad2826b5720d98e5e92a05b82bab17a97
# ("bond_length_violation_loss", Allen Goodman, 2024-05-01).
import torch
from torch import Tensor

from beignet.constants import AMINO_ACID_3


def f4(
    dx: Tensor,
    je: Tensor,
    s2: Tensor,
    r0: Tensor,
    zc: float = 12.0,
    gj: float = 12.0,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """
    Flat-bottom loss penalizing peptide-bond geometry violations between
    consecutive residues.

    This is the inlined replacement for the former
    ``bond_length_violation_loss``; the parameter descriptions below follow
    that function's signature and docstring position by position.

    Parameters
    ----------
    dx : Tensor, shape=(*, N, 37/14, 3)
        Atom positions. Along the atom axis, index 0 is read as N, index 1 as
        CA and index 2 as C.
    je : Tensor, shape=(*, N, 37/14)
        Atom mask, indexed the same way as ``dx``.
    s2 : Tensor, shape=(*, N)
        Residue index. Two neighbouring entries are treated as bonded only
        when they differ by exactly 1.
    r0 : Tensor, shape=(*, N)
        Amino acid type, encoded as the position of the residue name in
        ``[*AMINO_ACID_3, "UNK"]``.
    zc : float, optional
        Soft tolerance factor, in standard deviations, used for the loss
        terms. Default, 12.0.
    gj : float, optional
        Hard tolerance factor, in standard deviations, used for the violation
        mask. Default, 12.0.

    Returns
    -------
    tuple of Tensor
        ``(c_n_loss_mean, ca_c_n_loss_mean, c_n_ca_loss_mean,
        per_residue_loss_sum, per_residue_violation_mask)``:

        * masked mean loss for C-N bond length violations, shape ``(*)``
        * masked mean loss for the CA-C-N angle cosine, shape ``(*)``
        * masked mean loss for the C-N-CA angle cosine, shape ``(*)``
        * per-residue loss, each bond's loss split equally between its two
          residues, shape ``(*, N)``
        * per-residue mask marking residues that take part in a violating
          bond, shape ``(*, N)``
    """
    eps = torch.finfo(dx.dtype).eps
    relu = torch.nn.functional.relu
    pad = torch.nn.functional.pad

    residue_to_index = {
        name: index for index, name in enumerate([*AMINO_ACID_3, "UNK"])
    }

    # The C-N bond to proline has a slightly different length because of the
    # ring, so target and standard deviation depend on the *next* residue.
    next_is_proline = r0[..., 1:] == residue_to_index["PRO"]

    c_n_length_target = ~next_is_proline * 1.329 + next_is_proline * 1.341
    c_n_length_stddev = ~next_is_proline * 0.014 + next_is_proline * 0.016

    # Backbone atoms on either side of each peptide bond.
    this_ca = dx[..., :-1, 1, :]
    this_c = dx[..., :-1, 2, :]
    next_n = dx[..., 1:, 0, :]
    next_ca = dx[..., 1:, 1, :]

    this_ca_mask = je[..., :-1, 1]
    this_c_mask = je[..., :-1, 2]
    next_n_mask = je[..., 1:, 0]
    next_ca_mask = je[..., 1:, 1]

    # eps under the square root keeps the gradient finite at zero distance.
    ca_c_length = torch.sqrt(torch.sum((this_ca - this_c) ** 2, dim=-1) + eps)
    n_ca_length = torch.sqrt(torch.sum((next_n - next_ca) ** 2, dim=-1) + eps)
    c_n_length = torch.sqrt(torch.sum((this_c - next_n) ** 2, dim=-1) + eps)

    # Smooth absolute deviation of the C-N bond length from its target.
    c_n_error = torch.sqrt((c_n_length - c_n_length_target) ** 2 + eps)

    # Cosines of the CA-C-N angle (at C) and the C-N-CA angle (at N).
    c_n_unit = (next_n - this_c) / c_n_length[..., None]

    ca_c_n_cosine = torch.sum(
        (this_ca - this_c) / ca_c_length[..., None] * c_n_unit,
        dim=-1,
    )
    c_n_ca_cosine = torch.sum(
        (next_ca - next_n) / n_ca_length[..., None] * -c_n_unit,
        dim=-1,
    )

    # Target cosines are -0.4473 and -0.5203; these literals replaced the
    # ADJACENT_RESIDUE_PSI_COSINE / ADJACENT_RESIDUE_PHI_COSINE lookups.
    ca_c_n_error = torch.sqrt((ca_c_n_cosine + 0.4473) ** 2 + eps)
    c_n_ca_error = torch.sqrt((c_n_ca_cosine + 0.5203) ** 2 + eps)

    # Flat-bottom: no penalty inside `zc` standard deviations.
    c_n_loss = relu(c_n_error - zc * c_n_length_stddev)
    ca_c_n_loss = relu(ca_c_n_error - zc * 0.0140)
    c_n_ca_loss = relu(c_n_ca_error - zc * 0.0353)

    # Only pairs whose residue indices differ by exactly one are bonded.
    # The subtraction must happen BEFORE the comparison: comparing first and
    # subtracting the boolean afterwards yields a number, not a mask.
    has_no_gap = (s2[..., 1:] - s2[..., :-1]) == 1.0

    c_n_mask = this_c_mask * next_n_mask * has_no_gap
    ca_c_n_mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap
    c_n_ca_mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap

    # Masked means; eps guards against an all-zero mask.
    c_n_loss_mean = torch.sum(c_n_loss * c_n_mask, dim=-1) / (
        torch.sum(c_n_mask, dim=-1) + eps
    )
    ca_c_n_loss_mean = torch.sum(ca_c_n_loss * ca_c_n_mask, dim=-1) / (
        torch.sum(ca_c_n_mask, dim=-1) + eps
    )
    c_n_ca_loss_mean = torch.sum(c_n_ca_loss * c_n_ca_mask, dim=-1) / (
        torch.sum(c_n_ca_mask, dim=-1) + eps
    )

    # Distribute each bond's loss equally over its two residues.
    per_bond_loss = c_n_ca_loss + c_n_loss + ca_c_n_loss

    per_residue_loss_sum = 0.5 * (
        pad(per_bond_loss, [0, 1]) + pad(per_bond_loss, [1, 0])
    )

    # NOTE(review): as in the pre-refactor code, all three errors are compared
    # against the same threshold (gj * 0.0353) and gated by the C-N-CA mask;
    # confirm that a per-term standard deviation is not intended here.
    hard_threshold = gj * 0.0353

    violations = torch.stack(
        [
            (error > hard_threshold) * c_n_ca_mask
            for error in (c_n_error, c_n_ca_error, ca_c_n_error)
        ],
        dim=-2,
    )

    per_bond_violation, _ = torch.max(violations, dim=-2)

    # A residue is flagged when either of its neighbouring bonds violates.
    per_residue_violation_mask = torch.maximum(
        pad(per_bond_violation, [0, 1]),
        pad(per_bond_violation, [1, 0]),
    )

    return (
        c_n_loss_mean,
        ca_c_n_loss_mean,
        c_n_ca_loss_mean,
        per_residue_loss_sum,
        per_residue_violation_mask,
    )