Skip to content

Commit

Permalink
sequence alignment operators
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 13, 2024
1 parent 52163c0 commit 87af20b
Show file tree
Hide file tree
Showing 102 changed files with 7,040 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
.coverage
.hypothesis/
.idea/
.ipynb_checkpoints/
.pytest_cache/
.ruff_cache/
__pycache__/
build/
dist/
notebooks/
venv/
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ repos:
- id: "check-toml"
- id: "check-yaml"
repo: "https://github.com/pre-commit/pre-commit-hooks"
rev: "v4.5.0"
rev: "v4.6.0"
- hooks:
- args:
- "--fix"
id: "ruff"
- id: "ruff-format"
repo: "https://github.com/astral-sh/ruff-pre-commit"
rev: "v0.3.5"
rev: "v0.3.7"
11 changes: 11 additions & 0 deletions docs/beignet.ops.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# beignet.ops

## Geometry

### Transformations

#### Rotations

#### Translations

## Interpolation
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ test = [
"scipy",
]

[tool.ruff]
exclude = [
"./src/beignet/constants/_substitution_matrices.py",
]

[tool.ruff]
select = [
"B", # FLAKE8-BUGBEAR
Expand Down
119 changes: 119 additions & 0 deletions src/beignet/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from ._substitution_matrices import (
BLOSUM45,
BLOSUM50,
BLOSUM62,
BLOSUM80,
BLOSUM90,
BLOSUM_VOCABULARY,
PAM10,
PAM20,
PAM30,
PAM40,
PAM50,
PAM60,
PAM70,
PAM80,
PAM90,
PAM100,
PAM110,
PAM120,
PAM130,
PAM140,
PAM150,
PAM160,
PAM170,
PAM180,
PAM190,
PAM200,
PAM210,
PAM220,
PAM230,
PAM240,
PAM250,
PAM260,
PAM270,
PAM280,
PAM290,
PAM300,
PAM310,
PAM320,
PAM330,
PAM340,
PAM350,
PAM360,
PAM370,
PAM380,
PAM390,
PAM400,
PAM410,
PAM420,
PAM430,
PAM440,
PAM450,
PAM460,
PAM470,
PAM480,
PAM490,
PAM500,
PAM_VOCABULARY,
)

__all__ = [
"BLOSUM45",
"BLOSUM50",
"BLOSUM62",
"BLOSUM80",
"BLOSUM90",
"BLOSUM_VOCABULARY",
"PAM10",
"PAM20",
"PAM30",
"PAM40",
"PAM50",
"PAM60",
"PAM70",
"PAM80",
"PAM90",
"PAM100",
"PAM110",
"PAM120",
"PAM130",
"PAM140",
"PAM150",
"PAM160",
"PAM170",
"PAM180",
"PAM190",
"PAM200",
"PAM210",
"PAM220",
"PAM230",
"PAM240",
"PAM250",
"PAM260",
"PAM270",
"PAM280",
"PAM290",
"PAM300",
"PAM310",
"PAM320",
"PAM330",
"PAM340",
"PAM350",
"PAM360",
"PAM370",
"PAM380",
"PAM390",
"PAM400",
"PAM410",
"PAM420",
"PAM430",
"PAM440",
"PAM450",
"PAM460",
"PAM470",
"PAM480",
"PAM490",
"PAM500",
"PAM_VOCABULARY",
]
1,434 changes: 1,434 additions & 0 deletions src/beignet/constants/_substitution_matrices.py

Large diffs are not rendered by default.

104 changes: 104 additions & 0 deletions src/beignet/datasets/_msa_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from pathlib import Path
from typing import Callable

from torch import Tensor
from torch.utils.data import Dataset

import numpy
import pooch
import torch

class MSADataset(Dataset):
def __init__(
self,
root: str | Path,
*,
download: bool = False,
transform: Callable | None = None,
target_transform: Callable | None = None,
):
if isinstance(root, str):
root = Path(root)

name = self.__class__.__name__

if download:
pooch.retrieve(
f"https://files.ipd.uw.edu/krypton/data_unalign.npz",
fname=f"{name}.npz",
known_hash="9cc22e381619b66fc353c079221fd02450705d4e3ee23e4e23a052b6e70a95ec",
path=root / name,
)

self.all_data = numpy.load(root / name / f"{name}.npz", allow_pickle=True)

all_sequences = []
all_alignments = []
all_sizes = []
all_matrices = []

# process each subset
for subset in self.all_data.files:
data = self.all_data[subset].tolist()

# pad sequences
sequences = torch.nested.to_padded_tensor(
torch.nested.nested_tensor(data["ms"]),
0.0,
)
sequences = torch.concatenate(
[
torch.eye(torch.max(sequences) + 1),
torch.zeros([1, torch.max(sequences) + 1]),
],
)[sequences]

reference_sequence, sequences = sequences[0].unsqueeze(0), sequences[1:]
all_sequences.append(sequences)

sizes = torch.tensor([len(seq) for seq in sequences])
all_sizes.append(sizes)

# pad alignments
alignments = torch.nested.to_padded_tensor(
torch.nested.nested_tensor(data["aln"]),
0.0,
)

alignments = torch.concatenate(
[
torch.eye(torch.max(alignments) + 1),
torch.zeros([1, torch.max(alignments) + 1]),
],
)[alignments]

_, alignments = alignments[0], alignments[1:] # ignore first alignment
all_alignments.append(alignments)

matrices = make_similarity_matrices(sequences, reference_sequence) # TODO (Edith): make matrices
all_matrices.append(matrices)

self.sequences = torch.stack(all_sequences, dim=1)
self.alignments = torch.stack(all_alignments, dim=1)
self.sizes = torch.stack(all_sizes, dim=1)
self.matrices = torch.stack(all_matrices, dim=1)

self.transform = transform

self.target_transform = target_transform

def __len__(self):
return self.sequences.size(0)

def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
inputs = self.matrices[index], self.sizes[index]

if self.transform:
inputs = self.transform(*inputs)

target = self.alignments[index]

if self.target_transform:
target = self.target_transform(target)

return inputs, target
1 change: 1 addition & 0 deletions src/beignet/lightning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._msa_lightning_module import MSALightningModule
27 changes: 27 additions & 0 deletions src/beignet/lightning/_msa_lightning_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from lightning import LightningModule

from beignet.nn import MSA


class MSALightningModule(LightningModule):
def __init__(
self,
in_channels: int,
out_channels: int = 512,
kernel_size: int = 18,
*,
gap_penalty: float = 0.0,
temperature: float = 1.0,
):
super().__init__()

self.module = MSA(
in_channels,
out_channels,
kernel_size,
gap_penalty=gap_penalty,
temperature=temperature,
)

def forward(self, x):
return self.model(x)
1 change: 1 addition & 0 deletions src/beignet/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._msa import MSA
56 changes: 56 additions & 0 deletions src/beignet/nn/_msa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
from torch import Tensor
from torch.nn import Conv1d, Module

import beignet.operators


class MSA(Module):
def __init__(
self,
in_channels: int,
out_channels: int = 512,
kernel_size: int = 18,
*,
gap_penalty: float = 0.0,
temperature: float = 1.0,
):
super().__init__()

self.gap_penalty = gap_penalty

self.temperature = temperature

self.embedding = Conv1d(
in_channels,
out_channels,
kernel_size,
padding="same",
)

def forward(self, inputs: (Tensor, Tensor)) -> Tensor:
matrices, shapes = inputs

embedding = self.embedding(matrices)

embedding = embedding @ embedding[0].T

output = beignet.operators.needleman_wunsch(
embedding,
shapes,
gap_penalty=self.gap_penalty,
temperature=self.temperature,
)

return torch.einsum(
"ja, nij -> nia",
torch.mean(
torch.einsum(
"nia, nij -> nja",
matrices,
output,
),
dim=0,
),
output,
)
7 changes: 7 additions & 0 deletions src/beignet/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._needleman_wunsch import needleman_wunsch
from ._smith_waterman import smith_waterman

__all__ = [
"needleman_wunsch",
"smith_waterman",
]
Loading

0 comments on commit 87af20b

Please sign in to comment.