sequence alignment operators
0x00b1 committed Apr 17, 2024
1 parent 3e4ffd0 commit 1891185
Showing 7 changed files with 403 additions and 2 deletions.
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,13 @@
repos:
- hooks:
  - id: "check-toml"
  - id: "check-yaml"
  repo: "https://github.com/pre-commit/pre-commit-hooks"
  rev: "v4.5.0"
- hooks:
  - args:
    - "--fix"
    id: "ruff"
  - id: "ruff-format"
  repo: "https://github.com/astral-sh/ruff-pre-commit"
  rev: "v0.3.5"
21 changes: 19 additions & 2 deletions pyproject.toml
@@ -14,8 +14,25 @@ name = "beignet"
 readme = "README.md"
 requires-python = ">=3.10"
 
-[tool.ruff.format]
-docstring-code-format = true
+[project.optional-dependencies]
+test = [
+    "hypothesis",
+    "pytest",
+]
+
+[tool.ruff]
+select = [
+    "B",  # FLAKE8-BUGBEAR
+    "E",  # PYCODESTYLE ERRORS
+    "F",  # PYFLAKES
+    "I",  # ISORT
+    "W",  # PYCODESTYLE WARNINGS
+]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = [
+    "F401",  # MODULE IMPORTED BUT UNUSED
+]
 
 [tool.setuptools_scm]
 local_scheme = "no-local-version"
Empty file.
193 changes: 193 additions & 0 deletions src/beignet/operators/_needleman_wunsch.py
@@ -0,0 +1,193 @@
from typing import Optional, Tuple

import torch
import torch.nn.functional
from torch import Tensor


def needleman_wunsch(
input: Tensor,
lengths: Tuple[int, int],
*,
gap: float = 0.0,
temperature: float = 1.0,
out: Optional[Tensor] = None,
) -> Tensor:
"""
Compute the Needleman-Wunsch alignment score for two sequences.
The Needleman-Wunsch algorithm is a global sequence alignment method used
to identify regions of similarity between two sequences.
Parameters
----------
input : Tensor
The similarity matrix of the two sequences.
lengths : Sequence[int, int]
A sequence containing the lengths of the two sequences being aligned.
gap : float, optional
The penalty for creating a gap in alignment. Default is 0.
temperature : float, optional
Scaling factor to control the sharpness of the score distribution.
Default is 1.0.
out : Tensor, optional
Output tensor
Returns
-------
Tensor
Needleman-Wunsch alignment score for the given sequences.
"""
x = torch.nn.functional.pad(input, [1, 0, 1, 0])

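# `i` is the anti-diagonal ("wavefront") index of every entry in the padded
# matrix: entries with equal row + column share a value of `i`.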
i = torch.add(
torch.subtract(
torch.arange(x.shape[1])[None, :],
torch.flip(
torch.arange(x.shape[0]),
dims=[0],
)[:, None],
),
torch.subtract(
torch.tensor(x.shape[0]),
torch.tensor(1),
),
)

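# `j` is the entry's offset within its anti-diagonal in the skewed layout.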
j = torch.floor_divide(
torch.add(
torch.flip(
torch.arange(x.shape[0]),
dims=[0],
)[:, None],
torch.arange(x.shape[1])[None, :],
),
2,
)

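# Store the table anti-diagonal-major: `n` anti-diagonals, each with at
# most `m` entries.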
n = (x.shape[0] + x.shape[1]) - 1
m = (x.shape[0] + x.shape[1]) // 2

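# Boundary conditions: multiples of the gap penalty along the first row and
# first column of the padded table, scattered into the skewed layout.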
x_a = torch.zeros([n, m])

x_a[i, j] = torch.concatenate(
[
torch.concatenate(
[
torch.zeros([1, 1]),
torch.multiply(
torch.arange(1, x.shape[1]).view(1, -1),
gap,
),
],
dim=1,
),
torch.concatenate(
[
torch.multiply(
torch.arange(1, x.shape[0]).view(-1, 1),
gap,
),
torch.zeros([x.shape[0] - 1, x.shape[1] - 1]),
],
dim=1,
),
],
dim=0,
)

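# Mask that is one for cells covered by the two sequence lengths and zero
# for the boundary row and column and for any padding beyond the lengths.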
x_b = torch.zeros([n, m])

x_b[i, j] = torch.nn.functional.pad(
torch.multiply(
torch.less(
torch.arange(input.shape[0]),
lengths[0],
)[:, None],
torch.less(
torch.arange(input.shape[1]),
lengths[1],
)[None, :],
),
[1, 0, 1, 0],
).to(x.dtype)

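# Parity of each anti-diagonal; it selects the direction in which the
# previous wavefront must be shifted to align with the current one.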
x_c = torch.fmod(
torch.add(
torch.arange(n),
torch.fmod(
torch.tensor(x.shape[0]),
2,
),
),
2,
)

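# The padded similarity matrix itself, in the skewed layout.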
x_d = torch.zeros([n, m])

x_d[i, j] = x

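# The recursion needs only the two most recent wavefronts.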
previous = torch.zeros([m]), torch.zeros([m])

scores = []

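# Sweep the anti-diagonals, taking a smoothed maximum,
# temperature * logsumexp(. / temperature), over the three Needleman-Wunsch
# moves: a diagonal match or mismatch and a gap in either sequence.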
for a, b, c, d in zip(x_a, x_b, x_c, x_d, strict=False):
current = torch.add(
torch.multiply(
torch.multiply(
torch.special.logsumexp(
torch.divide(
torch.stack(
[
torch.add(
previous[0],
d,
),
torch.add(
previous[1],
gap,
),
torch.add(
torch.add(
torch.multiply(
torch.nn.functional.pad(
previous[1][:-1],
[1, 0],
),
c,
),
torch.multiply(
torch.nn.functional.pad(
previous[1][1:],
[0, 1],
),
torch.subtract(
torch.tensor(1),
c,
),
),
),
gap,
),
],
),
temperature,
),
dim=0,
),
temperature,
),
b,
),
a,
)

previous = previous[1], current

scores.append(current)

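# Map the wavefronts back to matrix coordinates and read the score at the
# cell indexed by the two sequence lengths.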
return torch.stack(scores, out=out)[i, j][lengths[0], lengths[1]]
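
For context, the following is a minimal sketch of how the new operator might be called. The import path, the one-hot similarity construction, and the gap and temperature values are illustrative assumptions, not usage taken from this commit:

import torch

# Assumed import path, inferred from the file location in this commit; the
# public re-export, if any, is not shown on this page.
from beignet.operators._needleman_wunsch import needleman_wunsch

# Two toy sequences over a four-letter alphabet, one-hot encoded.
a = torch.nn.functional.one_hot(torch.tensor([0, 1, 2, 3]), 4).float()
b = torch.nn.functional.one_hot(torch.tensor([0, 1, 3, 3]), 4).float()

# Similarity matrix: +1 for a match, -1 for a mismatch.
similarity = 2.0 * (a @ b.T) - 1.0

# A negative gap value penalizes insertions and deletions (assumed
# convention); a small temperature approximates a hard maximum.
score = needleman_wunsch(
    similarity,
    (a.shape[0], b.shape[0]),
    gap=-1.0,
    temperature=0.1,
)

print(score)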