Skip to content

Commit

Permalink
cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
0x00b1 committed May 13, 2024
1 parent 5c8ef4b commit aac5988
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/beignet/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._msa_dataset import MSADataset
24 changes: 15 additions & 9 deletions src/beignet/datasets/_msa_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from pathlib import Path
from typing import Callable

from torch import Tensor
from torch.utils.data import Dataset

import numpy
import pooch
import torch
from torch import Tensor
from torch.utils.data import Dataset


def make_similarity_matrices(*args):
    """Placeholder for similarity-matrix construction; not yet implemented.

    Accepts any positional arguments, ignores them, and returns ``None``.
    The call site in this file passes ``sequences`` and
    ``reference_sequence`` and marks the real implementation as a TODO.

    Returns:
        None, always, until the real implementation is written.
    """
    # Explicit ``None`` so the stub's contract is obvious to readers.
    return None


class MSADataset(Dataset):
def __init__(
Expand All @@ -24,7 +28,7 @@ def __init__(

if download:
pooch.retrieve(
f"https://files.ipd.uw.edu/krypton/data_unalign.npz",
"https://files.ipd.uw.edu/krypton/data_unalign.npz",
fname=f"{name}.npz",
known_hash="9cc22e381619b66fc353c079221fd02450705d4e3ee23e4e23a052b6e70a95ec",
path=root / name,
Expand All @@ -41,7 +45,7 @@ def __init__(
for subset in self.all_data.files:
data = self.all_data[subset].tolist()

# pad sequences
# pad sequences
sequences = torch.nested.to_padded_tensor(
torch.nested.nested_tensor(data["ms"]),
0.0,
Expand All @@ -59,7 +63,7 @@ def __init__(
sizes = torch.tensor([len(seq) for seq in sequences])
all_sizes.append(sizes)

# pad alignments
# pad alignments
alignments = torch.nested.to_padded_tensor(
torch.nested.nested_tensor(data["aln"]),
0.0,
Expand All @@ -72,10 +76,12 @@ def __init__(
],
)[alignments]

_, alignments = alignments[0], alignments[1:] # ignore first alignment
_, alignments = alignments[0], alignments[1:] # ignore first alignment
all_alignments.append(alignments)

matrices = make_similarity_matrices(sequences, reference_sequence) # TODO (Edith): make matrices
matrices = make_similarity_matrices(
sequences, reference_sequence
) # TODO (Edith): make matrices
all_matrices.append(matrices)

self.sequences = torch.stack(all_sequences, dim=1)
Expand All @@ -101,4 +107,4 @@ def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
if self.target_transform:
target = self.target_transform(target)

return inputs, target
return inputs, target

0 comments on commit aac5988

Please sign in to comment.