Skip to content

Commit

Permalink
per_residue_local_distance_difference_test
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 30, 2024
1 parent 61e8ef3 commit 743ef7d
Showing 1 changed file with 5 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# ruff: noqa: E501

import torch
import torch.nn.functional
from torch import Tensor
Expand All @@ -13,17 +15,10 @@ def per_residue_local_distance_difference_test(input: Tensor) -> Tensor:
-------
output : Tensor
"""
probs = torch.nn.functional.softmax(input, dim=-1)

bins = input.shape[-1]
output = torch.nn.functional.softmax(input, dim=-1)

step = 1.0 / bins
step = 1.0 / input.shape[-1]

bounds = torch.arange(0.5 * step, 1.0, step)

indexes = (1,) * len(probs.shape[:-1])
output = bounds.view(*indexes, *bounds.shape)
output = probs * output
output = torch.sum(output, dim=-1)

return output * 100
return torch.sum(output * torch.reshape(bounds, [*[1 for _ in range(len(output.shape[:-1]))], *bounds.shape]), dim=-1) * 100.0 # fmt: off

0 comments on commit 743ef7d

Please sign in to comment.