Replace MSADataset with SMURFDataset.

  * src/beignet/datasets/__init__.py: import SMURFDataset instead of MSADataset.
  * src/beignet/datasets/_msa_dataset.py: deleted.
  * src/beignet/datasets/_smurf_dataset.py: new. Reads the arrays named in
    FAMILIES_TRAIN or FAMILIES_TEST (selected by `train`) from the .npz archive
    and copies them into preallocated [num_sequences, 583] tensors.
  * src/beignet/datasets/_smurf_dataset_constants.py: new. Family ID lists and
    the matching NUM_SEQUENCES_* totals.

Review notes on the added code (to confirm before merging):

  * `all_sizes` is filled from `len(seq) for seq in sequences`, but `sequences`
    is already the padded tensor at that point, so every row of a family gets
    the same value (the padded width). Presumably the per-sequence length was
    intended -- verify against the consumer of `all_sizes`.
  * The alignment chunk is written with `chunk[:, :sequences.shape[1]]`; this
    only works when the padded alignment width equals the padded sequence
    width. `alignments.shape[1]` looks like the intended bound -- confirm.
  * The width 583 is hard-coded in six places; NOTE(review): looks like the
    maximum padded length in the archive -- confirm and name it once.
  * `__getitem__` is annotated `tuple[Tensor, Tensor]`, but `inputs` is a
    3-tuple (or whatever `transform` returns).
  * _smurf_dataset_constants.py has no trailing newline.

diff --git a/src/beignet/datasets/__init__.py b/src/beignet/datasets/__init__.py
index b3cb555401..64765ebe84 100644
--- a/src/beignet/datasets/__init__.py
+++ b/src/beignet/datasets/__init__.py
@@ -1 +1 @@
-from ._msa_dataset import MSADataset
+from ._smurf_dataset import SMURFDataset
diff --git a/src/beignet/datasets/_msa_dataset.py b/src/beignet/datasets/_msa_dataset.py
deleted file mode 100644
index f5481cac47..0000000000
--- a/src/beignet/datasets/_msa_dataset.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from pathlib import Path
-from typing import Callable
-
-import numpy
-import pooch
-import torch
-from torch import Tensor
-from torch.utils.data import Dataset
-
-
-def make_similarity_matrices(*args):
-    return
-
-
-class MSADataset(Dataset):
-    def __init__(
-        self,
-        root: str | Path,
-        *,
-        download: bool = False,
-        transform: Callable | None = None,
-        target_transform: Callable | None = None,
-    ):
-        if isinstance(root, str):
-            root = Path(root)
-
-        name = self.__class__.__name__
-
-        if download:
-            pooch.retrieve(
-                "https://files.ipd.uw.edu/krypton/data_unalign.npz",
-                fname=f"{name}.npz",
-                known_hash="9cc22e381619b66fc353c079221fd02450705d4e3ee23e4e23a052b6e70a95ec",
-                path=root / name,
-            )
-
-        self.all_data = numpy.load(root / name / f"{name}.npz", allow_pickle=True)
-
-        all_sequences = []
-        all_alignments = []
-        all_sizes = []
-        all_matrices = []
-
-        # process each subset
-        for subset in self.all_data.files:
-            data = self.all_data[subset].tolist()
-
-            # pad sequences
-            sequences = torch.nested.to_padded_tensor(
-                torch.nested.nested_tensor(data["ms"]),
-                0.0,
-            )
-            sequences = torch.concatenate(
-                [
-                    torch.eye(torch.max(sequences) + 1),
-                    torch.zeros([1, torch.max(sequences) + 1]),
-                ],
-            )[sequences]
-
-            reference_sequence, sequences = sequences[0].unsqueeze(0), sequences[1:]
-            all_sequences.append(sequences)
-
-            sizes = torch.tensor([len(seq) for seq in sequences])
-            all_sizes.append(sizes)
-
-            # pad alignments
-            alignments = torch.nested.to_padded_tensor(
-                torch.nested.nested_tensor(data["aln"]),
-                0.0,
-            )
-
-            alignments = torch.concatenate(
-                [
-                    torch.eye(torch.max(alignments) + 1),
-                    torch.zeros([1, torch.max(alignments) + 1]),
-                ],
-            )[alignments]
-
-            _, alignments = alignments[0], alignments[1:]  # ignore first alignment
-            all_alignments.append(alignments)
-
-            matrices = make_similarity_matrices(
-                sequences, reference_sequence
-            )  # TODO (Edith): make matrices
-            all_matrices.append(matrices)
-
-        self.sequences = torch.stack(all_sequences, dim=1)
-        self.alignments = torch.stack(all_alignments, dim=1)
-        self.sizes = torch.stack(all_sizes, dim=1)
-        self.matrices = torch.stack(all_matrices, dim=1)
-
-        self.transform = transform
-
-        self.target_transform = target_transform
-
-    def __len__(self):
-        return self.sequences.size(0)
-
-    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
-        inputs = self.matrices[index], self.sizes[index]
-
-        if self.transform:
-            inputs = self.transform(*inputs)
-
-        target = self.alignments[index]
-
-        if self.target_transform:
-            target = self.target_transform(target)
-
-        return inputs, target
diff --git a/src/beignet/datasets/_smurf_dataset.py b/src/beignet/datasets/_smurf_dataset.py
new file mode 100644
index 0000000000..2ace5329be
--- /dev/null
+++ b/src/beignet/datasets/_smurf_dataset.py
@@ -0,0 +1,107 @@
+from pathlib import Path
+from typing import Callable
+
+import numpy
+import pooch
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+
+from ._smurf_dataset_constants import FAMILIES_TEST, FAMILIES_TRAIN, NUM_SEQUENCES_TEST, NUM_SEQUENCES_TRAIN
+
+class SMURFDataset(Dataset):
+    def __init__(
+        self,
+        root: str | Path,
+        *,
+        download: bool = False,
+        train: bool = True,
+        transform: Callable | None = None,
+        target_transform: Callable | None = None,
+    ):
+        if isinstance(root, str):
+            root = Path(root)
+
+        name = self.__class__.__name__
+
+        if download:
+            pooch.retrieve(
+                "https://files.ipd.uw.edu/krypton/data_unalign.npz",
+                fname=f"{name}.npz",
+                known_hash="9cc22e381619b66fc353c079221fd02450705d4e3ee23e4e23a052b6e70a95ec",
+                path=root / name,
+            )
+
+        self.all_data = numpy.load(root / name / f"{name}.npz",
+            allow_pickle=True, mmap_mode="r"
+        )
+
+        if train:
+            families = FAMILIES_TRAIN
+            num_sequences = NUM_SEQUENCES_TRAIN
+        else:
+            families = FAMILIES_TEST
+            num_sequences = NUM_SEQUENCES_TEST
+
+        self.all_sequences = torch.zeros([num_sequences, 583])
+        self.all_references = torch.zeros([num_sequences, 583])
+        self.all_alignments = torch.zeros([num_sequences, 583])
+        self.all_sizes = torch.zeros([num_sequences, 1])
+
+        idx = 0
+
+        for family in families:
+            data = self.all_data[family].tolist()
+
+            # sequences
+            sequences = torch.nested.to_padded_tensor(
+                torch.nested.nested_tensor(data["ms"]),
+                0.0
+            )
+            reference_sequence, sequences = sequences[0], sequences[1:]
+
+            chunk = torch.zeros([sequences.shape[0], 583])
+            chunk[:, :sequences.shape[1]] = sequences
+            self.all_sequences[idx:idx+sequences.shape[0], :] = chunk
+
+            chunk = torch.zeros([sequences.shape[0], 583])
+            chunk[:, :sequences.shape[1]] = reference_sequence.repeat(
+                (sequences.shape[0], 1)
+            )
+            self.all_references[idx:idx+sequences.shape[0], :] = chunk
+
+            # alignments
+            alignments = torch.nested.to_padded_tensor(
+                torch.nested.nested_tensor(data["aln"]),
+                0.0
+            )
+            _, alignments = alignments[0], alignments[1:]  # discard the first alignment
+
+            chunk = torch.zeros([alignments.shape[0], 583])
+            chunk[:, :sequences.shape[1]] = alignments
+            self.all_alignments[idx:idx+sequences.shape[0], :] = chunk
+
+            # sizes
+            self.all_sizes[idx:idx+sequences.shape[0], :] = torch.tensor([len(seq) for seq in sequences]).unsqueeze(1)  # noqa: E501
+
+            idx += sequences.shape[0]
+
+        self.transform = transform
+
+        self.target_transform = target_transform
+
+    def __len__(self):
+        return self.all_sequences.size(0)
+
+    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
+        inputs = self.all_sequences[index], self.all_references[index], self.all_sizes[index]  # noqa: E501
+
+        if self.transform:
+            inputs = self.transform(*inputs)
+
+        target = self.all_alignments[index]
+
+        if self.target_transform:
+            target = self.target_transform(target)
+
+        return inputs, target
diff --git a/src/beignet/datasets/_smurf_dataset_constants.py b/src/beignet/datasets/_smurf_dataset_constants.py
new file mode 100644
index 0000000000..914974a702
--- /dev/null
+++ b/src/beignet/datasets/_smurf_dataset_constants.py
@@ -0,0 +1,392 @@
+FAMILIES_TEST = [
+    '3F6GA',
+    '1F86A',
+    '3F9XA',
+    '2P3WA',
+    '3FDXA',
+    '5FH7A',
+    '2FIUA',
+    '1FM0D',
+    '1FM0E',
+    '3KZPA',
+    '3FPNB',
+    '4FVGA',
+    '1FXLA',
+    '4G0XA',
+    '3G13A',
+    '1G2RA',
+    '3G2EA',
+    '3NO0A',
+    '4G4KA',
+    '1G6XA',
+    '3G9KF',
+    '1GCIA',
+    '3U2UA',
+    '4GJZA',
+    '2GMYA',
+    '1GPRA',
+    '3GPKA',
+    '1N62C',
+    '4GWBA',
+    '3H0NA',
+    '3LYHA',
+    '4HYLA',
+    '3H5JA',
+    '3H7OA',
+    '4H8EA',
+    '1H99A',
+    '2HA8A',
+    '2HBAA',
+    '2HBWA',
+    '2HC8A',
+    '4XPCA',
+    '3HMCA',
+    '5HMLA',
+    '3HRLA',
+    '3HY3A',
+    '2HZQA',
+    '2JFRA',
+    '5I32A',
+    '2I6HA',
+    '3IBZA',
+    '2ICUA',
+    '1IIBA',
+    '2IIHA',
+    '5IJAA',
+    '3IP0A',
+    '2NWRA',
+    '1IQ4A',
+    '2IQYA',
+    '3IS6A',
+    '3ISRA',
+    '3IT4A',
+    '3IT4B',
+    '3ITQA',
+    '5ITQA',
+    '3IUGA',
+    '2IXDA',
+    '3OCMA',
+    '2IZ6A',
+    '3PVEA',
+    '1JKEA',
+    '1JL1A',
+    '2JLIA',
+    '3JXGA',
+    '1K4IA',
+    '4LFLA',
+    '2OMLA',
+    '3K8UA',
+    '1M93B',
+    '1MPGA',
+    '3KEWA',
+    '1KHYA',
+    '1KNMA',
+    '2NUHA',
+    '3L00A',
+    '3L51A',
+    '3L60A',
+    '1LFPA',
+    '3LF9A',
+    '1LOPA',
+    '3LQBA',
+    '4LQ4A',
+    '4LWRA',
+    '1LYQA',
+    '4M0NA',
+    '3M7AA',
+    '4ME3A',
+    '2MHRA',
+    '3MHXA',
+    '3MMHA',
+    '4MU3A',
+    '1MVLA',
+    '3MVUA',
+    '1X1OA',
+    '4NBXA',
+    '3NFDA',
+    '3NO4A',
+    '2NRKA',
+    '2NRRA',
+    '1NS5A',
+    '4NTKA',
+    '3NUAA',
+    '1NZ0A',
+    '2O70A',
+    '2OFKA',
+    '2OLMA',
+    '2OMKA',
+    '4ONMA',
+    '3PN3A',
+    '1WDJA',
+    '2OYAA',
+    '1OZ9A',
+    '1TQ5A',
+    '2PFRA',
+    '4PGRA',
+    '2PLIA',
+    '3PO8A',
+    '3POJA',
+    '4PUIA',
+    '3PYWA',
+    '3Q46A',
+    '3Q64A',
+    '2Q7SA',
+    '4QDNA',
+    '2QF4A',
+    '1QG8A',
+    '2QIFA',
+    '2QIPA',
+    '2VTCA',
+    '2QQ4A',
+    '1R5LA',
+    '1YB0A',
+    '1RSSA',
+    '1RV9A',
+    '4RWUA',
+    '1S7IA',
+    '1SEIA',
+    '1SJ1A',
+    '1SUMB',
+    '1TIFA',
+    '1TIGA',
+    '1TQGA',
+    '3UBYA',
+    '1UCDA',
+    '4UC1A',
+    '1UEBA',
+    '1UI0A',
+    '1USMA',
+    '1V6TA',
+    '2VE8A',
+    '1VGJA',
+    '1VMHA',
+    '2VXNA',
+    '1VZYA',
+    '4WK7A',
+    '1W2WA',
+    '1W2WB',
+    '2W6KA',
+    '4W9ZA',
+    '4WEEA',
+    '1WJXA',
+    '1WNYA',
+    '2WNPF',
+    '4WPKA',
+    '2WQKA',
+    '1X9UA',
+    '3WSGA',
+    '1WUBA',
+    '1WURA',
+    '2X8XX',
+    '4X84A',
+    '4X9JA',
+    '2XOVA',
+    '2XTYA',
+    '1Y6ZA',
+    '2Y71A',
+    '1YARH',
+    '1YD0A',
+    '2YN5A',
+    '1Z0WA',
+    '4YQDA',
+    '1ZAVA',
+    '3ZJAA',
+    '2ZPTX'
+]
+
+NUM_SEQUENCES_TEST = 719_168
+
+FAMILIES_TRAIN = [
+    '3A0YA',
+    '4ACIA',
+    '3AH7A',
+    '5A62A',
+    '2A4VA',
+    '1A3AA',
+    '3A35A',
+    '5A35A',
+    '3GM5A',
+    '2A67A',
+    '3A6SA',
+    '1NNHA',
+    '4A7UA',
+    '4A7WA',
+    '5ECCA',
+    '5A89A',
+    '5C0PA',
+    '2ABWA',
+    '2A9SA',
+    '4JS8A',
+    '3AABA',
+    '3AAYA',
+    '4C5KA',
+    '4ABLA',
+    '3ACXA',
+    '1AE9A',
+    '4AFFA',
+    '4AFHA',
+    '3AGYA',
+    '1M2KA',
+    '4AIVA',
+    '4AIWA',
+    '1AKOA',
+    '3AK8A',
+    '3AKBA',
+    '3ALUA',
+    '2AMHA',
+    '2AN1A',
+    '2ANRA',
+    '2ANXA',
+    '4WTPA',
+    '2APJA',
+    '2FBNA',
+    '4APXB',
+    '2AQ6A',
+    '1ATZA',
+    '4ATEA',
+    '4AVRA',
+    '3AWUA',
+    '4AY0A',
+    '2B0AA',
+    '1EUWA',
+    '4LSCA',
+    '5B3PA',
+    '2B5GA',
+    '2HQSC',
+    '3B8BA',
+    '4B8EA',
+    '2B94A',
+    '1BD8A',
+    '3BEDA',
+    '3BEMA',
+    '3SY1A',
+    '2BFWA',
+    '5C1EA',
+    '4BH5A',
+    '2C2IA',
+    '2BK8A',
+    '2BKMA',
+    '2BKXA',
+    '4LXQA',
+    '3BM7A',
+    '2BOUA',
+    '3BP3A',
+    '3BPKA',
+    '5C90A',
+    '3KG9A',
+    '3BR8A',
+    '3BT5A',
+    '1BUOA',
+    '2BV5A',
+    '3BWUD',
+    '3BWUF',
+    '1I4JA',
+    '1BXYA',
+    '1BYRA',
+    '3BY8A',
+    '5BY4A',
+    '2BZ1A',
+    '4N0KA',
+    '3C1QA',
+    '4C24A',
+    '3C37A',
+    '3C4BA',
+    '2C5QA',
+    '4C6AA',
+    '4C6SA',
+    '2C71A',
+    '2C8MA',
+    '2C92A',
+    '1K7KA',
+    '5CAJA',
+    '3CCDA',
+    '3CCGA',
+    '5CEGB',
+    '1WPNA',
+    '1CFBA',
+    '1CHDA',
+    '2GGCA',
+    '3CH0A',
+    '3LULA',
+    '3MN2A',
+    '4EWFA',
+    '3NREA',
+    '3CI3A',
+    '4NNOA',
+    '3CNVA',
+    '2EGZA',
+    '1COJA',
+    '1COZA',
+    '3CQ1A',
+    '5CQXA',
+    '1KQPA',
+    '1CTFA',
+    '1CUKA',
+    '2CVEA',
+    '1CXQA',
+    '3LF5A',
+    '3CXKA',
+    '5CX7A',
+    '3CZXA',
+    '1D0QA',
+    '3D01A',
+    '3D03A',
+    '4D05A',
+    '4D74A',
+    '4DBFA',
+    '2D4XA',
+    '3LTJA',
+    '4DAMA',
+    '3DBOB',
+    '3DD6A',
+    '1H72C',
+    '4DE9A',
+    '1DFUP',
+    '3DFGA',
+    '1DJ0A',
+    '2DQWA',
+    '3FDJA',
+    '2DTJA',
+    '4DT4A',
+    '4DUNA',
+    '2DXAA',
+    '4XTVA',
+    '2DYIA',
+    '2DYJA',
+    '2E0NA',
+    '2E11A',
+    '1K7JA',
+    '4E3YA',
+    '3UF6A',
+    '1E58A',
+    '2E5YA',
+    '2GUIA',
+    '2EBJA',
+    '3EERA',
+    '4HOIA',
+    '2EGVA',
+    '3OHEA',
+    '3EJKA',
+    '1EKEA',
+    '1EKJA',
+    '4ONWA',
+    '4EOJB',
+    '4EQPA',
+    '3ERBA',
+    '3ERSX',
+    '2G2CA',
+    '4ES1A',
+    '2EW0A',
+    '4K08A',
+    '3F0DA',
+    '4F01A',
+    '2F1FA',
+    '2F23A',
+    '5F3MA',
+    '3F42A',
+    '2F5GA',
+    '3F5VA',
+    '4F55A'
+]
+
+NUM_SEQUENCES_TRAIN = 950_762
\ No newline at end of file