ruff
Edith Lee committed May 31, 2024
1 parent 6e0887c commit 4aaff60
Showing 2 changed files with 411 additions and 401 deletions.
44 changes: 27 additions & 17 deletions src/beignet/datasets/_smurf_dataset.py
@@ -7,7 +7,13 @@
 from torch import Tensor
 from torch.utils.data import Dataset
 
-from ._smurf_dataset_constants import FAMILIES_TEST, FAMILIES_TRAIN, NUM_SEQUENCES_TEST, NUM_SEQUENCES_TRAIN
+from ._smurf_dataset_constants import (
+    FAMILIES_TEST,
+    FAMILIES_TRAIN,
+    NUM_SEQUENCES_TEST,
+    NUM_SEQUENCES_TRAIN,
+)
 
 
 class SMURFDataset(Dataset):
     def __init__(
@@ -32,9 +38,9 @@ def __init__(
             path=root / name,
         )
 
-        self.all_data = numpy.load(root / name / f"{name}.npz",
-            allow_pickle=True, mmap_mode="r"
-        )
+        self.all_data = numpy.load(
+            root / name / f"{name}.npz", allow_pickle=True, mmap_mode="r"
+        )
 
         if train:
             families = FAMILIES_TRAIN
@@ -55,34 +61,34 @@ def __init__(
 
             # sequences
             sequences = torch.nested.to_padded_tensor(
-                torch.nested.nested_tensor(data["ms"]),
-                0.0
+                torch.nested.nested_tensor(data["ms"]), 0.0
             )
             reference_sequence, sequences = sequences[0], sequences[1:]
 
             chunk = torch.zeros([sequences.shape[0], 583])
-            chunk[:, :sequences.shape[1]] = sequences
-            self.all_sequences[idx:idx+sequences.shape[0], :] = chunk
+            chunk[:, : sequences.shape[1]] = sequences
+            self.all_sequences[idx : idx + sequences.shape[0], :] = chunk
 
             chunk = torch.zeros([sequences.shape[0], 583])
-            chunk[:, :sequences.shape[1]] = reference_sequence.repeat(
+            chunk[:, : sequences.shape[1]] = reference_sequence.repeat(
                 (sequences.shape[0], 1)
             )
-            self.all_references[idx:idx+sequences.shape[0], :] = chunk
+            self.all_references[idx : idx + sequences.shape[0], :] = chunk
 
             # alignments
             alignments = torch.nested.to_padded_tensor(
-                torch.nested.nested_tensor(data["aln"]),
-                0.0
+                torch.nested.nested_tensor(data["aln"]), 0.0
             )
-            _, alignments = alignments[0], alignments[1:] # discard the first alignment
+            _, alignments = alignments[0], alignments[1:]  # discard the first alignment
 
             chunk = torch.zeros([alignments.shape[0], 583])
-            chunk[:, :sequences.shape[1]] = alignments
-            self.all_alignments[idx:idx+sequences.shape[0], :] = chunk
+            chunk[:, : sequences.shape[1]] = alignments
+            self.all_alignments[idx : idx + sequences.shape[0], :] = chunk
 
             # sizes
-            self.all_sizes[idx:idx+sequences.shape[0], :] = torch.tensor([len(seq) for seq in sequences]).unsqueeze(1) # noqa: E501
+            self.all_sizes[idx : idx + sequences.shape[0], :] = torch.tensor(
+                [len(seq) for seq in sequences]
+            ).unsqueeze(1)  # noqa: E501
 
             idx += sequences.shape[0]
 
@@ -94,7 +100,11 @@ def __len__(self):
         return self.all_sequences.size(0)
 
     def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
-        inputs = self.all_sequences[index], self.all_references[index], self.all_sizes[index] # noqa: E501
+        inputs = (
+            self.all_sequences[index],
+            self.all_references[index],
+            self.all_sizes[index],
+        )  # noqa: E501
 
         if self.transform:
             inputs = self.transform(*inputs)
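Note: this commit is a pure ruff reformatting pass over the dataset class, so behavior is unchanged. For orientation, a minimal usage sketch follows; the exact SMURFDataset constructor signature, the public import path, and the value returned by __getitem__ after the optional transform are not visible in this diff, so the argument names (root, train, transform) and the access pattern below are assumptions drawn from the hunks above.

    # Hypothetical usage sketch: argument names, import path, and the returned
    # structure are assumptions inferred from the diff above, not a documented API.
    from pathlib import Path

    from beignet.datasets import SMURFDataset  # import path assumed

    # Other arguments (e.g. a dataset name) may be required by __init__.
    dataset = SMURFDataset(root=Path("./data"), train=True, transform=None)

    print(len(dataset))  # number of rows in all_sequences (padded to width 583)

    item = dataset[0]
    # Per the diff, the raw inputs are (sequence, reference, size) slices of the
    # preallocated tensors; the exact shape of the returned value after the
    # optional transform is not shown in these hunks.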
